From 9bc60dc6182e6e5753da09f2013920124bddbacc Mon Sep 17 00:00:00 2001
From: Reactor Scram
Date: Wed, 4 Sep 2024 14:20:45 -0500
Subject: [PATCH] refactor(rust/gui-client): close callbacks when closing
 connlib (#6590)

Closes #6576

This recreates the callback channel on every connect / disconnect cycle,
to prevent this sequence:

1. Start connlib
2. Fail in `make_tun`
3. Spend several seconds doing platform-specific things
4. Stop connlib (since `make_tun` failed)
5. Come back to the main loop to find a bunch of queued-up callbacks even
   though connlib is supposed to be stopped.

Instead we get:

5\. Come back to the main loop and we've dropped the callback receiver, so
any callbacks that connlib sent while we were busy are either dropped or
not even sent.
---
 rust/headless-client/src/ipc_service.rs | 113 ++++++++++++++----------
 1 file changed, 64 insertions(+), 49 deletions(-)

diff --git a/rust/headless-client/src/ipc_service.rs b/rust/headless-client/src/ipc_service.rs
index cab0d832f..43b984527 100644
--- a/rust/headless-client/src/ipc_service.rs
+++ b/rust/headless-client/src/ipc_service.rs
@@ -4,7 +4,7 @@ use crate::{
 };
 use anyhow::{bail, Context as _, Result};
 use clap::Parser;
-use connlib_client_shared::{keypair, ConnectArgs, LoginUrl, Session};
+use connlib_client_shared::{keypair, ConnectArgs, LoginUrl};
 use connlib_shared::callbacks::ResourceDescription;
 use firezone_bin_shared::{
     platform::{tcp_socket_factory, udp_socket_factory, DnsControlMethod},
@@ -236,17 +236,20 @@ async fn ipc_listen(
 
 /// Handles one IPC client
 struct Handler<'a> {
-    callback_handler: CallbackHandler,
-    cb_rx: mpsc::Receiver<ConnlibMsg>,
-    connlib: Option<Session>,
     dns_controller: &'a mut DnsController,
     ipc_rx: ipc::ServerRead,
     ipc_tx: ipc::ServerWrite,
     last_connlib_start_instant: Option<Instant>,
     log_filter_reloader: &'a LogFilterReloader,
+    session: Option<Session>,
     tun_device: TunDeviceManager,
 }
 
+struct Session {
+    cb_rx: mpsc::Receiver<ConnlibMsg>,
+    connlib: connlib_client_shared::Session,
+}
+
 enum Event {
     Callback(ConnlibMsg),
     CallbackChannelClosed,
@@ -275,18 +278,15 @@ impl<'a> Handler<'a> {
             .next_client_split()
             .await
             .context("Failed to wait for incoming IPC connection from a GUI")?;
-        let (cb_tx, cb_rx) = mpsc::channel(1_000);
         let tun_device = TunDeviceManager::new(DEFAULT_MTU)?;
 
         Ok(Self {
-            callback_handler: CallbackHandler { cb_tx },
-            cb_rx,
-            connlib: None,
             dns_controller,
             ipc_rx,
             ipc_tx,
             last_connlib_start_instant: None,
             log_filter_reloader,
+            session: None,
             tun_device,
         })
     }
@@ -353,18 +353,20 @@ impl<'a> Handler<'a> {
         }
         // `FramedRead::next` is cancel-safe.
         if let Poll::Ready(result) = pin!(&mut self.ipc_rx).poll_next(cx) {
-            return match result {
-                Some(Ok(x)) => Poll::Ready(Event::Ipc(x)),
-                Some(Err(error)) => Poll::Ready(Event::IpcError(error)),
-                None => Poll::Ready(Event::IpcDisconnected),
-            };
+            return Poll::Ready(match result {
+                Some(Ok(x)) => Event::Ipc(x),
+                Some(Err(error)) => Event::IpcError(error),
+                None => Event::IpcDisconnected,
+            });
         }
-        // `tokio::sync::mpsc::Receiver::recv` is cancel-safe.
-        if let Poll::Ready(option) = self.cb_rx.poll_recv(cx) {
-            return match option {
-                Some(x) => Poll::Ready(Event::Callback(x)),
-                None => Poll::Ready(Event::CallbackChannelClosed),
-            };
+        if let Some(session) = self.session.as_mut() {
+            // `tokio::sync::mpsc::Receiver::recv` is cancel-safe.
+            if let Poll::Ready(option) = session.cb_rx.poll_recv(cx) {
+                return Poll::Ready(match option {
+                    Some(x) => Event::Callback(x),
+                    None => Event::CallbackChannelClosed,
+                });
+            }
         }
         Poll::Pending
     }
@@ -409,10 +411,6 @@ impl<'a> Handler<'a> {
         Ok(())
     }
 
-    fn tunnel_is_ready(&self) -> bool {
-        self.last_connlib_start_instant.is_none() && self.connlib.is_some()
-    }
-
     async fn handle_ipc_msg(&mut self, msg: ClientMsg) -> Result<()> {
         match msg {
             ClientMsg::ClearLogs => {
@@ -438,36 +436,48 @@ impl<'a> Handler<'a> {
                     .context("Failed to send `ConnectResult`")?
             }
             ClientMsg::Disconnect => {
-                if let Some(connlib) = self.connlib.take() {
-                    connlib.disconnect();
-                    self.dns_controller.deactivate()?;
-                } else {
+                let Some(session) = self.session.take() else {
                     tracing::error!("Error - Got Disconnect when we're already not connected");
-                }
+                    return Ok(());
+                };
+
+                // Identical to dropping it, but looks nicer.
+                session.connlib.disconnect();
+                self.dns_controller.deactivate()?;
             }
             ClientMsg::ReloadLogFilter => {
                 let filter = spawn_blocking(get_log_filter).await??;
                 self.log_filter_reloader.reload(filter)?;
             }
             ClientMsg::Reset => {
-                if self.tunnel_is_ready() {
-                    self.connlib.as_mut().context("No connlib session")?.reset();
-                } else {
-                    tracing::debug!("Ignoring redundant reset");
+                if self.last_connlib_start_instant.is_some() {
+                    tracing::debug!("Ignoring reset since we're still signing in");
+                    return Ok(());
                 }
+                let Some(session) = self.session.as_ref() else {
+                    tracing::debug!("Cannot reset if we're signed out");
+                    return Ok(());
+                };
+
+                session.connlib.reset();
             }
             ClientMsg::SetDns(resolvers) => {
+                let Some(session) = self.session.as_ref() else {
+                    tracing::debug!("Cannot set DNS resolvers if we're signed out");
+                    return Ok(());
+                };
+
                 tracing::debug!(?resolvers);
-                self.connlib
-                    .as_mut()
-                    .context("No connlib session")?
-                    .set_dns(resolvers)
+                session.connlib.set_dns(resolvers);
             }
             ClientMsg::SetDisabledResources(disabled_resources) => {
-                self.connlib
-                    .as_mut()
-                    .context("No connlib session")?
-                    .set_disabled_resources(disabled_resources);
+                let Some(session) = self.session.as_ref() else {
+                    // At this point, the GUI has already saved the disabled Resources to disk, so it'll be correct on the next sign-in anyway.
+                    tracing::debug!("Cannot set disabled resources if we're signed out");
+                    return Ok(());
+                };
+
+                session.connlib.set_disabled_resources(disabled_resources);
             }
         }
         Ok(())
     }
@@ -479,10 +489,7 @@ impl<'a> Handler<'a> {
     ///
     /// Throws matchable errors for bad URLs, unable to reach the portal, or unable to create the tunnel device
     fn connect_to_firezone(&mut self, api_url: &str, token: SecretString) -> Result<(), Error> {
-        // There isn't an airtight way to implement a "disconnect and reconnect"
-        // right now because `Session::disconnect` is fire-and-forget:
-
-        assert!(self.connlib.is_none());
+        assert!(self.session.is_none());
 
         let device_id = device_id::get_or_create().map_err(|e| Error::DeviceId(e.to_string()))?;
         let (private_key, public_key) = keypair();
@@ -496,11 +503,13 @@ impl<'a> Handler<'a> {
             .map_err(|e| Error::LoginUrl(e.to_string()))?;
 
         self.last_connlib_start_instant = Some(Instant::now());
+        let (cb_tx, cb_rx) = mpsc::channel(1_000);
+        let callbacks = CallbackHandler { cb_tx };
         let args = ConnectArgs {
             tcp_socket_factory: Arc::new(tcp_socket_factory),
             udp_socket_factory: Arc::new(udp_socket_factory),
             private_key,
-            callbacks: self.callback_handler.clone(),
+            callbacks,
         };
 
         // Synchronous DNS resolution here
@@ -518,16 +527,22 @@ impl<'a> Handler<'a> {
         // Read the resolvers before starting connlib, in case connlib's startup interferes.
         let dns = self.dns_controller.system_resolvers();
 
-        let new_session = Session::connect(args, portal, tokio::runtime::Handle::current());
+        let connlib = connlib_client_shared::Session::connect(
+            args,
+            portal,
+            tokio::runtime::Handle::current(),
+        );
         // Call `set_dns` before `set_tun` so that the tunnel starts up with a valid list of resolvers.
         tracing::debug!(?dns, "Calling `set_dns`...");
-        new_session.set_dns(dns);
+        connlib.set_dns(dns);
         let tun = self
             .tun_device
             .make_tun()
             .map_err(|e| Error::TunnelDevice(e.to_string()))?;
-        new_session.set_tun(Box::new(tun));
-        self.connlib = Some(new_session);
+        connlib.set_tun(Box::new(tun));
+
+        let session = Session { cb_rx, connlib };
+        self.session = Some(session);
 
         Ok(())
     }
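
The core of the fix is ordinary `tokio::sync::mpsc` lifetime behaviour: once the `Receiver` half of the channel is dropped, messages already buffered in it are discarded and any later `send` fails instead of queuing. Below is a minimal, standalone sketch of that behaviour in the same shape as the patch; `FakeConnlib`, `Callback`, and the `main` driver are invented for illustration and are not Firezone's actual types.

```rust
// Minimal sketch (not Firezone code). Assumes the `tokio` crate with the
// "sync", "macros", and "rt-multi-thread" features enabled.
use tokio::sync::mpsc;

#[derive(Debug)]
enum Callback {
    TunnelReady,
}

/// Hypothetical stand-in for `connlib_client_shared::Session`: all it does
/// here is hold a sender for callbacks.
struct FakeConnlib {
    cb_tx: mpsc::Sender<Callback>,
}

/// Mirrors the patch: the receiver lives inside the per-connection `Session`,
/// so dropping the `Session` also drops the receiver.
struct Session {
    // The real `Handler::next_event` polls this only while `session` is `Some`.
    #[allow(dead_code)]
    cb_rx: mpsc::Receiver<Callback>,
    connlib: FakeConnlib,
}

#[tokio::main]
async fn main() {
    // Recreated on every connect, like `connect_to_firezone` now does.
    let (cb_tx, cb_rx) = mpsc::channel(1_000);
    let connlib = FakeConnlib { cb_tx: cb_tx.clone() };
    let session = Session { cb_rx, connlib };

    // A callback arrives while the main loop is busy doing something else;
    // it sits in the channel's buffer.
    session
        .connlib
        .cb_tx
        .send(Callback::TunnelReady)
        .await
        .expect("receiver is alive, so this send is buffered");

    // Disconnect: dropping the `Session` drops `cb_rx`, so the buffered
    // callback above is discarded instead of being handled later.
    drop(session);

    // Anything connlib tries to send after that fails outright.
    assert!(cb_tx.send(Callback::TunnelReady).await.is_err());
    println!("queued callbacks died with the session");
}
```

In the patch itself this plays out in two places: `connect_to_firezone` builds a fresh channel and bundles the receiver into the new `Session` struct, and `next_event` only polls `session.cb_rx` while `self.session` is `Some`, so a connlib instance that was stopped after a failed `make_tun` can no longer feed stale events into the main loop.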