From 4f92a0d7cac56468ce3a06a8eea7db32228ea3b7 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 2 Dec 2024 20:07:44 +0000 Subject: [PATCH] refactor(gui-client): tidy up GUI controller code (#7444) This PR intends to be a pure refactoring, i.e. no behaviour change. It simplifies a few aspects of the GUI controller event-loop by getting rid of the `select!` macro. We also remove some indirection of the `gui_controller::Builder`. --- rust/Cargo.lock | 1 + .../bin-shared/src/network_changes/windows.rs | 5 +- rust/gui-client/src-common/Cargo.toml | 1 + rust/gui-client/src-common/src/controller.rs | 272 +++++++++++------- rust/gui-client/src-tauri/src/client/gui.rs | 35 +-- rust/headless-client/src/main.rs | 9 +- 6 files changed, 170 insertions(+), 153 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c9a08b2ff..116e650b7 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2049,6 +2049,7 @@ dependencies = [ "thiserror", "time", "tokio", + "tokio-stream", "tracing", "tracing-log", "tracing-subscriber", diff --git a/rust/bin-shared/src/network_changes/windows.rs b/rust/bin-shared/src/network_changes/windows.rs index 72e0c7e0b..b341a268c 100644 --- a/rust/bin-shared/src/network_changes/windows.rs +++ b/rust/bin-shared/src/network_changes/windows.rs @@ -117,8 +117,9 @@ pub struct Worker { impl Drop for Worker { fn drop(&mut self) { - self.close() - .expect("should be able to close WorkerInner cleanly"); + if let Err(e) = self.close() { + tracing::error!(error = anyhow_dyn_err(&e), "Failed to close worker thread") + } } } diff --git a/rust/gui-client/src-common/Cargo.toml b/rust/gui-client/src-common/Cargo.toml index 5c650aee8..44755a056 100644 --- a/rust/gui-client/src-common/Cargo.toml +++ b/rust/gui-client/src-common/Cargo.toml @@ -31,6 +31,7 @@ subtle = { workspace = true } thiserror = { workspace = true } time = { workspace = true, features = ["formatting"] } tokio = { workspace = true } +tokio-stream = { workspace = true } tracing = { workspace = true } tracing-log = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } diff --git a/rust/gui-client/src-common/src/controller.rs b/rust/gui-client/src-common/src/controller.rs index 4247c1b0a..0fcf7bda1 100644 --- a/rust/gui-client/src-common/src/controller.rs +++ b/rust/gui-client/src-common/src/controller.rs @@ -8,16 +8,21 @@ use crate::{ }; use anyhow::{anyhow, Context, Result}; use connlib_model::ResourceView; -use firezone_bin_shared::{new_dns_notifier, new_network_notifier}; +use firezone_bin_shared::platform::DnsControlMethod; use firezone_headless_client::{ IpcClientMsg::{self, SetDisabledResources}, IpcServerMsg, IpcServiceError, LogFilterReloader, }; use firezone_logging::{anyhow_dyn_err, std_dyn_err}; use firezone_telemetry::Telemetry; +use futures::{ + stream::{self, BoxStream}, + Stream, StreamExt, +}; use secrecy::{ExposeSecret as _, SecretString}; -use std::{collections::BTreeSet, ops::ControlFlow, path::PathBuf, time::Instant}; +use std::{collections::BTreeSet, ops::ControlFlow, path::PathBuf, task::Poll, time::Instant}; use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::ReceiverStream; use url::Url; use ControllerRequest as Req; @@ -34,60 +39,19 @@ pub struct Controller<'a, I: GuiIntegration> { clear_logs_callback: Option>>, ctlr_tx: CtlrTx, ipc_client: ipc::Client, - ipc_rx: mpsc::Receiver, + ipc_rx: ReceiverStream, integration: I, log_filter_reloader: LogFilterReloader, /// A release that's ready to download release: Option, - rx: mpsc::Receiver, + rx: ReceiverStream, status: Status, telemetry: &'a mut Telemetry, - updates_rx: mpsc::Receiver>, + updates_rx: ReceiverStream>, uptime: crate::uptime::Tracker, -} -pub struct Builder<'a, I: GuiIntegration> { - pub advanced_settings: AdvancedSettings, - pub ctlr_tx: CtlrTx, - pub integration: I, - pub log_filter_reloader: LogFilterReloader, - pub rx: mpsc::Receiver, - pub telemetry: &'a mut Telemetry, - pub updates_rx: mpsc::Receiver>, -} - -impl<'a, I: GuiIntegration> Builder<'a, I> { - pub async fn build(self) -> Result> { - let Builder { - advanced_settings, - ctlr_tx, - integration, - log_filter_reloader, - rx, - telemetry, - updates_rx, - } = self; - - let (ipc_tx, ipc_rx) = mpsc::channel(1); - let ipc_client = ipc::Client::new(ipc_tx).await?; - - Ok(Controller { - advanced_settings, - auth: auth::Auth::new()?, - clear_logs_callback: None, - ctlr_tx, - ipc_client, - ipc_rx, - integration, - log_filter_reloader, - release: None, - rx, - status: Default::default(), - telemetry, - updates_rx, - uptime: Default::default(), - }) - } + dns_notifier: BoxStream<'static, Result<()>>, + network_notifier: BoxStream<'static, Result<()>>, } pub trait GuiIntegration { @@ -203,7 +167,56 @@ impl Status { } } +enum EventloopTick { + NetworkChanged(Result<()>), + DnsChanged(Result<()>), + IpcEvent(ipc::Event), + ControllerRequest(Option), + UpdateNotification(Option>), +} + impl<'a, I: GuiIntegration> Controller<'a, I> { + pub async fn start( + ctlr_tx: CtlrTx, + integration: I, + rx: mpsc::Receiver, + advanced_settings: AdvancedSettings, + log_filter_reloader: LogFilterReloader, + telemetry: &mut Telemetry, + updates_rx: mpsc::Receiver>, + ) -> Result<(), Error> { + tracing::debug!("Starting new instance of `Controller`"); + + let (ipc_tx, ipc_rx) = mpsc::channel(1); + let ipc_client = ipc::Client::new(ipc_tx).await?; + + let dns_notifier = new_dns_notifier().await?.boxed(); + let network_notifier = new_network_notifier().await?.boxed(); + + let controller = Controller { + advanced_settings, + auth: auth::Auth::new()?, + clear_logs_callback: None, + ctlr_tx, + ipc_client, + ipc_rx: ReceiverStream::new(ipc_rx), + integration, + log_filter_reloader, + release: None, + rx: ReceiverStream::new(rx), + status: Default::default(), + telemetry, + updates_rx: ReceiverStream::new(updates_rx), + uptime: Default::default(), + dns_notifier, + network_notifier, + }; + + controller.main_loop().await?; + + Ok(()) + } + pub async fn main_loop(mut self) -> Result<(), Error> { let account_slug = self.auth.session().map(|s| s.account_slug.to_owned()); @@ -242,80 +255,53 @@ impl<'a, I: GuiIntegration> Controller<'a, I> { self.integration.set_welcome_window_visible(true)?; } - let tokio_handle = tokio::runtime::Handle::current(); - let dns_control_method = Default::default(); - - let mut dns_notifier = new_dns_notifier(tokio_handle.clone(), dns_control_method).await?; - let mut network_notifier = - new_network_notifier(tokio_handle.clone(), dns_control_method).await?; - drop(tokio_handle); - - loop { - // TODO: Add `ControllerRequest::NetworkChange` and `DnsChange` and replace - // `tokio::select!` with a `poll_*` function - tokio::select! { - result = network_notifier.notified() => { - result?; + while let Some(tick) = self.tick().await { + match tick { + EventloopTick::NetworkChanged(Ok(())) => { if self.status.needs_network_changes() { tracing::debug!("Internet up/down changed, calling `Session::reset`"); self.ipc_client.reset().await? } + self.try_retry_connection().await? } - result = dns_notifier.notified() => { - result?; + EventloopTick::DnsChanged(Ok(())) => { if self.status.needs_network_changes() { - let resolvers = firezone_headless_client::dns_control::system_resolvers_for_gui()?; - tracing::debug!(?resolvers, "New DNS resolvers, calling `Session::set_dns`"); + let resolvers = + firezone_headless_client::dns_control::system_resolvers_for_gui()?; + tracing::debug!( + ?resolvers, + "New DNS resolvers, calling `Session::set_dns`" + ); self.ipc_client.set_dns(resolvers).await?; } + self.try_retry_connection().await? } - event = self.ipc_rx.recv() => { - let event = event.context("IPC task stopped")?; + EventloopTick::NetworkChanged(Err(e)) | EventloopTick::DnsChanged(Err(e)) => { + return Err(Error::Other(e)) + } + EventloopTick::IpcEvent(event) => { if let ControlFlow::Break(()) = self.handle_ipc_event(event).await? { break; } } - req = self.rx.recv() => { - let Some(req) = req else { - tracing::warn!("Controller channel closed, breaking main loop."); - break; - }; - - #[expect(clippy::wildcard_enum_match_arm)] - match req { - // SAFETY: Crashing is unsafe - Req::Fail(Failure::Crash) => { - tracing::error!("Crashing on purpose"); - unsafe { sadness_generator::raise_segfault() } - }, - Req::Fail(Failure::Error) => Err(anyhow!("Test error"))?, - Req::Fail(Failure::Panic) => panic!("Test panic"), - Req::SystemTrayMenu(TrayMenuEvent::Quit) => { - tracing::info!("User clicked Quit in the menu"); - self.status = Status::Quitting; - self.ipc_client.send_msg(&IpcClientMsg::Disconnect).await?; - self.refresh_system_tray_menu()?; - } - // TODO: Should we really skip cleanup if a request fails? - req => self.handle_request(req).await?, - } + EventloopTick::ControllerRequest(Some(req)) => self.handle_request(req).await?, + EventloopTick::ControllerRequest(None) => { + tracing::warn!("Controller channel closed, breaking main loop"); + break; + } + EventloopTick::UpdateNotification(Some(notification)) => { + self.handle_update_notification(notification)? + } + EventloopTick::UpdateNotification(None) => { + return Err(Error::Other(anyhow!("Update checker task stopped"))) } - notification = self.updates_rx.recv() => self.handle_update_notification(notification.context("Update checker task stopped")?)?, } - // Code down here may not run because the `select` sometimes `continue`s. } tracing::debug!("Closing..."); - if let Err(error) = dns_notifier.close() { - tracing::error!(error = anyhow_dyn_err(&error), "dns_notifier"); - } - if let Err(error) = network_notifier.close() { - tracing::error!(error = anyhow_dyn_err(&error), "network_notifier"); - } - if let Err(error) = self.ipc_client.disconnect_from_ipc().await { tracing::error!(error = anyhow_dyn_err(&error), "ipc_client"); } @@ -325,6 +311,35 @@ impl<'a, I: GuiIntegration> Controller<'a, I> { Ok(()) } + async fn tick(&mut self) -> Option { + std::future::poll_fn(|cx| { + if let Poll::Ready(Some(res)) = self.dns_notifier.poll_next_unpin(cx) { + return Poll::Ready(Some(EventloopTick::DnsChanged(res))); + } + + if let Poll::Ready(Some(res)) = self.network_notifier.poll_next_unpin(cx) { + return Poll::Ready(Some(EventloopTick::NetworkChanged(res))); + } + + if let Poll::Ready(maybe_ipc) = self.ipc_rx.poll_next_unpin(cx) { + return Poll::Ready(Some(EventloopTick::IpcEvent( + maybe_ipc.unwrap_or(ipc::Event::Closed), + ))); + } + + if let Poll::Ready(maybe_req) = self.rx.poll_next_unpin(cx) { + return Poll::Ready(Some(EventloopTick::ControllerRequest(maybe_req))); + } + + if let Poll::Ready(notification) = self.updates_rx.poll_next_unpin(cx) { + return Poll::Ready(Some(EventloopTick::UpdateNotification(notification))); + } + + Poll::Pending + }) + .await + } + async fn start_session(&mut self, token: SecretString) -> Result<(), Error> { match self.status { Status::Disconnected | Status::RetryingConnection { .. } => {} @@ -397,9 +412,13 @@ impl<'a, I: GuiIntegration> Controller<'a, I> { Req::ExportLogs { path, stem } => logging::export_logs_to(path, stem) .await .context("Failed to export logs to zip")?, - Req::Fail(_) => Err(anyhow!( - "Impossible error: `Fail` should be handled before this" - ))?, + Req::Fail(Failure::Crash) => { + tracing::error!("Crashing on purpose"); + // SAFETY: Crashing is unsafe + unsafe { sadness_generator::raise_segfault() } + }, + Req::Fail(Failure::Error) => Err(anyhow!("Test error"))?, + Req::Fail(Failure::Panic) => panic!("Test panic"), Req::GetAdvancedSettings(tx) => { tx.send(self.advanced_settings.clone()).ok(); } @@ -482,9 +501,12 @@ impl<'a, I: GuiIntegration> Controller<'a, I> { self.integration.open_url(&url) .context("Couldn't open URL from system tray")? } - Req::SystemTrayMenu(TrayMenuEvent::Quit) => Err(anyhow!( - "Impossible error: `Quit` should be handled before this" - ))?, + Req::SystemTrayMenu(TrayMenuEvent::Quit) => { + tracing::info!("User clicked Quit in the menu"); + self.status = Status::Quitting; + self.ipc_client.send_msg(&IpcClientMsg::Disconnect).await?; + self.refresh_system_tray_menu()?; + } Req::UpdateNotificationClicked(download_url) => { tracing::info!("UpdateNotificationClicked in run_controller!"); self.integration.open_url(&download_url) @@ -502,16 +524,18 @@ impl<'a, I: GuiIntegration> Controller<'a, I> { Err(Error::ConnectToFirezoneFailed(error)) => { tracing::error!("Failed to connect to Firezone: {error}"); self.sign_out().await?; + Ok(ControlFlow::Continue(())) } - Err(error) => Err(error)?, + Err(error) => Err(error), }, ipc::Event::ReadFailed(error) => { // IPC errors are always fatal tracing::error!(error = anyhow_dyn_err(&error), "IPC read failure"); - Err(Error::IpcRead)? + + Err(Error::IpcRead) } - ipc::Event::Closed => Err(Error::IpcClosed)?, + ipc::Event::Closed => Err(Error::IpcClosed), } } @@ -778,3 +802,31 @@ impl<'a, I: GuiIntegration> Controller<'a, I> { Ok(()) } } + +async fn new_dns_notifier() -> Result>> { + let worker = firezone_bin_shared::new_dns_notifier( + tokio::runtime::Handle::current(), + DnsControlMethod::default(), + ) + .await?; + + Ok(stream::try_unfold(worker, |mut worker| async move { + let () = worker.notified().await?; + + Ok(Some(((), worker))) + })) +} + +async fn new_network_notifier() -> Result>> { + let worker = firezone_bin_shared::new_network_notifier( + tokio::runtime::Handle::current(), + DnsControlMethod::default(), + ) + .await?; + + Ok(stream::try_unfold(worker, |mut worker| async move { + let () = worker.notified().await?; + + Ok(Some(((), worker))) + })) +} diff --git a/rust/gui-client/src-tauri/src/client/gui.rs b/rust/gui-client/src-tauri/src/client/gui.rs index 47db53190..9d180a7c3 100644 --- a/rust/gui-client/src-tauri/src/client/gui.rs +++ b/rust/gui-client/src-tauri/src/client/gui.rs @@ -11,7 +11,7 @@ use anyhow::{bail, Context, Result}; use common::system_tray::Event as TrayMenuEvent; use firezone_gui_client_common::{ self as common, - controller::{ControllerRequest, CtlrTx, GuiIntegration}, + controller::{Controller, ControllerRequest, CtlrTx, GuiIntegration}, deep_link, errors::{self, Error}, settings::AdvancedSettings, @@ -236,7 +236,7 @@ pub(crate) fn run( let app_handle = app.handle().clone(); let _ctlr_task = tokio::spawn(async move { - let result = AssertUnwindSafe(run_controller( + let result = AssertUnwindSafe(Controller::start( ctlr_tx, integration, ctlr_rx, @@ -467,34 +467,3 @@ fn handle_system_tray_event(app: &tauri::AppHandle, event: TrayMenuEvent) -> Res .blocking_send(ControllerRequest::SystemTrayMenu(event))?; Ok(()) } - -// TODO: Move this into `impl Controller` -async fn run_controller( - ctlr_tx: CtlrTx, - integration: TauriIntegration, - rx: mpsc::Receiver, - advanced_settings: AdvancedSettings, - log_filter_reloader: LogFilterReloader, - telemetry: &mut telemetry::Telemetry, - updates_rx: mpsc::Receiver>, -) -> Result<(), Error> { - tracing::debug!("Entered `run_controller`"); - - let controller = firezone_gui_client_common::controller::Builder { - advanced_settings, - ctlr_tx, - integration, - log_filter_reloader, - rx, - telemetry, - updates_rx, - } - .build() - .await?; - - controller.main_loop().await?; - - // Last chance to do any drops / cleanup before the process crashes. - - Ok(()) -} diff --git a/rust/headless-client/src/main.rs b/rust/headless-client/src/main.rs index cbddd927d..d1a3144f9 100644 --- a/rust/headless-client/src/main.rs +++ b/rust/headless-client/src/main.rs @@ -14,7 +14,7 @@ use firezone_bin_shared::{ use firezone_headless_client::{ device_id, signals, CallbackHandler, CliCommon, ConnlibMsg, DnsController, }; -use firezone_logging::{anyhow_dyn_err, telemetry_span}; +use firezone_logging::telemetry_span; use firezone_telemetry::Telemetry; use futures::StreamExt as _; use phoenix_channel::get_user_agent; @@ -316,13 +316,6 @@ fn main() -> Result<()> { } }; - if let Err(error) = dns_notifier.close() { - tracing::error!(error = anyhow_dyn_err(&error), "DNS notifier") - } - if let Err(error) = network_notifier.close() { - tracing::error!(error = anyhow_dyn_err(&error), "network notifier"); - } - telemetry.stop().await; // Stop telemetry before dropping session. `connlib` needs to be active for this, otherwise we won't be able to resolve the DNS name for sentry. session.disconnect();