From 228389882eeb113a56720a97716b01d33d70034b Mon Sep 17 00:00:00 2001 From: Jamil Date: Mon, 25 Mar 2024 15:51:35 -0700 Subject: [PATCH] refactor(connlib): delay initialization of `Sockets` until we have a tokio runtime (#4286) Our sockets need to be initialized within a tokio runtime context. To achieve this, we don't actually initialize anything on `Sockets::new`. Instead, we call `rebind` within the constructor of `Tunnel` which already runs in a tokio context. Fixes: #4282 --------- Signed-off-by: Jamil Co-authored-by: Thomas Eizinger Co-authored-by: Reactor Scram --- rust/connlib/clients/android/src/lib.rs | 43 +++--------- rust/connlib/clients/apple/src/lib.rs | 2 +- rust/connlib/clients/shared/src/eventloop.rs | 10 +-- rust/connlib/clients/shared/src/lib.rs | 12 ++-- rust/connlib/tunnel/src/io.rs | 15 ++-- rust/connlib/tunnel/src/lib.rs | 30 +++++--- rust/connlib/tunnel/src/sockets.rs | 73 ++++++++++++++------ rust/gateway/src/main.rs | 2 +- rust/gui-client/src-tauri/src/client/gui.rs | 4 +- rust/linux-client/src/main.rs | 4 +- 10 files changed, 110 insertions(+), 85 deletions(-) diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 2d2f15d41..7f0bfbbcf 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -388,28 +388,27 @@ fn connect( .enable_all() .build()?; - let sockets = Sockets::new()?; - - if let Some(ip4_socket) = sockets.ip4_socket_fd() { - callback_handler.protect_file_descriptor(ip4_socket)?; - } - if let Some(ip6_socket) = sockets.ip6_socket_fd() { - callback_handler.protect_file_descriptor(ip6_socket)?; - } + let sockets = Sockets::with_protect({ + let callbacks = callback_handler.clone(); + move |fd| { + callbacks + .protect_file_descriptor(fd) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } + }); let session = Session::connect( login, sockets, private_key, Some(os_version), - callback_handler.clone(), + callback_handler, 
Some(MAX_PARTITION_TIME), runtime.handle().clone(), )?; Ok(SessionWrapper { inner: session, - callbacks: callback_handler, runtime, }) } @@ -463,29 +462,11 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_co pub struct SessionWrapper { inner: Session, - callbacks: CallbackHandler, #[allow(dead_code)] // Only here so we don't drop the memory early. runtime: Runtime, } -impl SessionWrapper { - fn reconnect(&self) -> Result<(), CallbackError> { - let sockets = Sockets::new()?; - - if let Some(ip4_socket) = sockets.ip4_socket_fd() { - self.callbacks.protect_file_descriptor(ip4_socket)?; - } - if let Some(ip6_socket) = sockets.ip6_socket_fd() { - self.callbacks.protect_file_descriptor(ip6_socket)?; - } - - self.inner.reconnect(sockets); - - Ok(()) - } -} - /// # Safety /// Pointers must be valid #[allow(non_snake_case)] @@ -528,11 +509,9 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se #[allow(non_snake_case)] #[no_mangle] pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_reconnect( - mut env: JNIEnv, + _: JNIEnv, _: JClass, session: *const SessionWrapper, ) { - if let Err(e) = (*session).reconnect() { - throw(&mut env, "java/lang/Exception", e.to_string()); - } + (*session).inner.reconnect(); } diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index df95fce83..667271983 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -205,7 +205,7 @@ impl WrappedSession { let session = Session::connect( login, - Sockets::new().map_err(|err| err.to_string())?, + Sockets::new(), private_key, os_version_override, CallbackHandler { diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 325b58282..4adf1a85e 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -10,7 +10,7 @@ use connlib_shared::{ 
messages::{ConnectionAccepted, GatewayResponse, ResourceAccepted, ResourceId}, Callbacks, }; -use firezone_tunnel::{ClientTunnel, Sockets}; +use firezone_tunnel::ClientTunnel; use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; use std::{ collections::HashMap, @@ -37,7 +37,7 @@ pub struct Eventloop { /// Commands that can be sent to the [`Eventloop`]. pub enum Command { Stop, - Reconnect(Sockets), + Reconnect, SetDns(Vec), } @@ -71,9 +71,11 @@ where tracing::warn!("Failed to update DNS: {e}"); } } - Poll::Ready(Some(Command::Reconnect(sockets))) => { + Poll::Ready(Some(Command::Reconnect)) => { self.portal.reconnect(); - self.tunnel.reconnect(sockets); + if let Err(e) = self.tunnel.reconnect() { + tracing::warn!("Failed to reconnect tunnel: {e}"); + } continue; } diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 7397c6f10..b2ddd0003 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -67,11 +67,11 @@ impl Session { /// /// - Close and re-open a connection to the portal. /// - Refresh all allocations - /// - Replace the currently used [`Sockets`] with the provided one + /// - Rebind local UDP sockets /// /// # Implementation note /// - /// The reason we replace [`Sockets`] are: + /// The reasons we rebind the UDP sockets are: /// /// 1. On MacOS, as socket bound to the unspecified IP cannot send to interfaces attached after the socket has been created. /// 2. Switching between networks changes the 3-tuple of the client. @@ -80,9 +80,9 @@ impl Session { /// Changing the IP would be enough for that. /// However, if the user would now change _back_ to the previous network, /// the TURN server would recognise the old allocation but the client already lost all its state associated with it. - /// To avoid race-conditions like this, we initialize a new [`Sockets`] instance which allocates a new port. 
- pub fn reconnect(&self, sockets: Sockets) { - let _ = self.channel.send(Command::Reconnect(sockets)); + /// To avoid race-conditions like this, we rebind the sockets to a new port. + pub fn reconnect(&self) { + let _ = self.channel.send(Command::Reconnect); } /// Sets a new set of upstream DNS servers for this [`Session`]. @@ -118,7 +118,7 @@ async fn connect( where CB: Callbacks + 'static, { - let tunnel = ClientTunnel::new(private_key, sockets, callbacks.clone()); + let tunnel = ClientTunnel::new(private_key, sockets, callbacks.clone())?; let portal = PhoenixChannel::connect( Secret::new(url), diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 5926b1e9a..a63d644a8 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -43,8 +43,13 @@ pub enum Input<'a, I> { } impl Io { - pub fn new(sockets: Sockets) -> Self { - Self { + /// Creates a new I/O abstraction + /// + /// Must be called within a Tokio runtime context so we can bind the sockets. + pub fn new(mut sockets: Sockets) -> io::Result { + sockets.rebind()?; // Bind sockets on startup. Must happen within a tokio runtime context. 
+ + Ok(Self { device: Device::new(), timeout: None, sockets, @@ -53,7 +58,7 @@ impl Io { Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE, ), - } + }) } pub fn poll<'b>( @@ -115,8 +120,8 @@ impl Io { &self.sockets } - pub(crate) fn set_sockets(&mut self, sockets: Sockets) { - self.sockets = sockets; + pub fn sockets_mut(&mut self) -> &mut Sockets { + &mut self.sockets } pub fn set_upstream_dns_servers( diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 0bfcc86f6..e85d27e10 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -59,21 +59,27 @@ impl ClientTunnel where CB: Callbacks + 'static, { - pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self { - Self { - io: Io::new(sockets), + pub fn new( + private_key: StaticSecret, + sockets: Sockets, + callbacks: CB, + ) -> std::io::Result { + Ok(Self { + io: Io::new(sockets)?, callbacks, role_state: ClientState::new(private_key), write_buf: Box::new([0u8; MAX_UDP_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), device_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - } + }) } - pub fn reconnect(&mut self, sockets: Sockets) { + pub fn reconnect(&mut self) -> std::io::Result<()> { self.role_state.reconnect(Instant::now()); - self.io.set_sockets(sockets); + self.io.sockets_mut().rebind()?; + + Ok(()) } pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -149,16 +155,20 @@ impl GatewayTunnel where CB: Callbacks + 'static, { - pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self { - Self { - io: Io::new(sockets), + pub fn new( + private_key: StaticSecret, + sockets: Sockets, + callbacks: CB, + ) -> std::io::Result { + Ok(Self { + io: Io::new(sockets)?, callbacks, role_state: GatewayState::new(private_key), write_buf: Box::new([0u8; MAX_UDP_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), device_read_buf: 
Box::new([0u8; MAX_UDP_SIZE]), - } + }) } pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 8de835f93..78245c26a 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -13,10 +13,47 @@ use crate::Result; pub struct Sockets { socket_v4: Option, socket_v6: Option, + + #[cfg(unix)] + protect: Box io::Result<()> + Send + 'static>, +} + +impl Default for Sockets { + fn default() -> Self { + Self::new() + } } impl Sockets { - pub fn new() -> io::Result { + #[cfg(unix)] + pub fn with_protect( + protect: impl Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static, + ) -> Self { + Self { + socket_v4: None, + socket_v6: None, + #[cfg(unix)] + protect: Box::new(protect), + } + } + + pub fn new() -> Self { + Self { + socket_v4: None, + socket_v6: None, + #[cfg(unix)] + protect: Box::new(|_| Ok(())), + } + } + + pub fn can_handle(&self, addr: &SocketAddr) -> bool { + match addr { + SocketAddr::V4(_) => self.socket_v4.is_some(), + SocketAddr::V6(_) => self.socket_v6.is_some(), + } + } + + pub fn rebind(&mut self) -> io::Result<()> { let socket_v4 = Socket::ip4(); let socket_v6 = Socket::ip6(); @@ -39,31 +76,23 @@ impl Sockets { _ => (), } - Ok(Self { - socket_v4: socket_v4.ok(), - socket_v6: socket_v6.ok(), - }) - } + #[cfg(unix)] + { + use std::os::fd::AsRawFd; - pub fn can_handle(&self, addr: &SocketAddr) -> bool { - match addr { - SocketAddr::V4(_) => self.socket_v4.is_some(), - SocketAddr::V6(_) => self.socket_v6.is_some(), + if let Ok(fd) = socket_v4.as_ref().map(|s| s.socket.as_raw_fd()) { + (self.protect)(fd)?; + } + + if let Ok(fd) = socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) { + (self.protect)(fd)?; + } } - } - #[cfg(unix)] - pub fn ip4_socket_fd(&self) -> Option { - use std::os::fd::AsRawFd; + self.socket_v4 = socket_v4.ok(); + self.socket_v6 = socket_v6.ok(); - self.socket_v4.as_ref().map(|s| s.socket.as_raw_fd()) - 
} - - #[cfg(unix)] - pub fn ip6_socket_fd(&self) -> Option { - use std::os::fd::AsRawFd; - - self.socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) + Ok(()) } /// Flushes all buffered data on the sockets. diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 86f6f8ff0..e280bcc2e 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -91,7 +91,7 @@ async fn get_firezone_id(env_id: Option) -> Result { } async fn run(login: LoginUrl, private_key: StaticSecret) -> Result { - let mut tunnel = GatewayTunnel::new(private_key, Sockets::new()?, CallbackHandler); + let mut tunnel = GatewayTunnel::new(private_key, Sockets::new(), CallbackHandler)?; let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>( Secret::new(login), diff --git a/rust/gui-client/src-tauri/src/client/gui.rs b/rust/gui-client/src-tauri/src/client/gui.rs index 70f1df6c7..efa8f699a 100644 --- a/rust/gui-client/src-tauri/src/client/gui.rs +++ b/rust/gui-client/src-tauri/src/client/gui.rs @@ -536,7 +536,7 @@ impl Controller { )?; let connlib = connlib_client_shared::Session::connect( login, - Sockets::new()?, + Sockets::new(), private_key, None, callback_handler.clone(), @@ -837,7 +837,7 @@ async fn run_controller( have_internet = new_have_internet; if let Some(session) = controller.session.as_mut() { tracing::debug!("Internet up/down changed, calling `Session::reconnect`"); - session.connlib.reconnect(Sockets::new()?); + session.connlib.reconnect(); } } }, diff --git a/rust/linux-client/src/main.rs b/rust/linux-client/src/main.rs index e9988573b..35ae40dac 100644 --- a/rust/linux-client/src/main.rs +++ b/rust/linux-client/src/main.rs @@ -38,7 +38,7 @@ async fn main() -> Result<()> { let session = Session::connect( login, - Sockets::new()?, + Sockets::new(), private_key, None, callbacks.clone(), @@ -62,7 +62,7 @@ async fn main() -> Result<()> { if sighup.poll_recv(cx).is_ready() { tracing::debug!("Received SIGHUP"); - session.reconnect(Sockets::new()?); + 
session.reconnect(); continue; }