refactor(gui-client): tidy up GUI controller code (#7444)

This PR intends to be a pure refactoring, i.e. no behaviour change. It
simplifies a few aspects of the GUI controller event-loop by getting rid
of the `select!` macro. We also remove some indirection of the
`gui_controller::Builder`.
This commit is contained in:
Thomas Eizinger
2024-12-02 20:07:44 +00:00
committed by GitHub
parent 8bc1277c24
commit 4f92a0d7ca
6 changed files with 170 additions and 153 deletions

1
rust/Cargo.lock generated
View File

@@ -2049,6 +2049,7 @@ dependencies = [
"thiserror",
"time",
"tokio",
"tokio-stream",
"tracing",
"tracing-log",
"tracing-subscriber",

View File

@@ -117,8 +117,9 @@ pub struct Worker {
impl Drop for Worker {
fn drop(&mut self) {
self.close()
.expect("should be able to close WorkerInner cleanly");
if let Err(e) = self.close() {
tracing::error!(error = anyhow_dyn_err(&e), "Failed to close worker thread")
}
}
}

View File

@@ -31,6 +31,7 @@ subtle = { workspace = true }
thiserror = { workspace = true }
time = { workspace = true, features = ["formatting"] }
tokio = { workspace = true }
tokio-stream = { workspace = true }
tracing = { workspace = true }
tracing-log = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }

View File

@@ -8,16 +8,21 @@ use crate::{
};
use anyhow::{anyhow, Context, Result};
use connlib_model::ResourceView;
use firezone_bin_shared::{new_dns_notifier, new_network_notifier};
use firezone_bin_shared::platform::DnsControlMethod;
use firezone_headless_client::{
IpcClientMsg::{self, SetDisabledResources},
IpcServerMsg, IpcServiceError, LogFilterReloader,
};
use firezone_logging::{anyhow_dyn_err, std_dyn_err};
use firezone_telemetry::Telemetry;
use futures::{
stream::{self, BoxStream},
Stream, StreamExt,
};
use secrecy::{ExposeSecret as _, SecretString};
use std::{collections::BTreeSet, ops::ControlFlow, path::PathBuf, time::Instant};
use std::{collections::BTreeSet, ops::ControlFlow, path::PathBuf, task::Poll, time::Instant};
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
use url::Url;
use ControllerRequest as Req;
@@ -34,60 +39,19 @@ pub struct Controller<'a, I: GuiIntegration> {
clear_logs_callback: Option<oneshot::Sender<Result<(), String>>>,
ctlr_tx: CtlrTx,
ipc_client: ipc::Client,
ipc_rx: mpsc::Receiver<ipc::Event>,
ipc_rx: ReceiverStream<ipc::Event>,
integration: I,
log_filter_reloader: LogFilterReloader,
/// A release that's ready to download
release: Option<updates::Release>,
rx: mpsc::Receiver<ControllerRequest>,
rx: ReceiverStream<ControllerRequest>,
status: Status,
telemetry: &'a mut Telemetry,
updates_rx: mpsc::Receiver<Option<updates::Notification>>,
updates_rx: ReceiverStream<Option<updates::Notification>>,
uptime: crate::uptime::Tracker,
}
pub struct Builder<'a, I: GuiIntegration> {
pub advanced_settings: AdvancedSettings,
pub ctlr_tx: CtlrTx,
pub integration: I,
pub log_filter_reloader: LogFilterReloader,
pub rx: mpsc::Receiver<ControllerRequest>,
pub telemetry: &'a mut Telemetry,
pub updates_rx: mpsc::Receiver<Option<updates::Notification>>,
}
impl<'a, I: GuiIntegration> Builder<'a, I> {
pub async fn build(self) -> Result<Controller<'a, I>> {
let Builder {
advanced_settings,
ctlr_tx,
integration,
log_filter_reloader,
rx,
telemetry,
updates_rx,
} = self;
let (ipc_tx, ipc_rx) = mpsc::channel(1);
let ipc_client = ipc::Client::new(ipc_tx).await?;
Ok(Controller {
advanced_settings,
auth: auth::Auth::new()?,
clear_logs_callback: None,
ctlr_tx,
ipc_client,
ipc_rx,
integration,
log_filter_reloader,
release: None,
rx,
status: Default::default(),
telemetry,
updates_rx,
uptime: Default::default(),
})
}
dns_notifier: BoxStream<'static, Result<()>>,
network_notifier: BoxStream<'static, Result<()>>,
}
pub trait GuiIntegration {
@@ -203,7 +167,56 @@ impl Status {
}
}
enum EventloopTick {
NetworkChanged(Result<()>),
DnsChanged(Result<()>),
IpcEvent(ipc::Event),
ControllerRequest(Option<ControllerRequest>),
UpdateNotification(Option<Option<updates::Notification>>),
}
impl<'a, I: GuiIntegration> Controller<'a, I> {
pub async fn start(
ctlr_tx: CtlrTx,
integration: I,
rx: mpsc::Receiver<ControllerRequest>,
advanced_settings: AdvancedSettings,
log_filter_reloader: LogFilterReloader,
telemetry: &mut Telemetry,
updates_rx: mpsc::Receiver<Option<updates::Notification>>,
) -> Result<(), Error> {
tracing::debug!("Starting new instance of `Controller`");
let (ipc_tx, ipc_rx) = mpsc::channel(1);
let ipc_client = ipc::Client::new(ipc_tx).await?;
let dns_notifier = new_dns_notifier().await?.boxed();
let network_notifier = new_network_notifier().await?.boxed();
let controller = Controller {
advanced_settings,
auth: auth::Auth::new()?,
clear_logs_callback: None,
ctlr_tx,
ipc_client,
ipc_rx: ReceiverStream::new(ipc_rx),
integration,
log_filter_reloader,
release: None,
rx: ReceiverStream::new(rx),
status: Default::default(),
telemetry,
updates_rx: ReceiverStream::new(updates_rx),
uptime: Default::default(),
dns_notifier,
network_notifier,
};
controller.main_loop().await?;
Ok(())
}
pub async fn main_loop(mut self) -> Result<(), Error> {
let account_slug = self.auth.session().map(|s| s.account_slug.to_owned());
@@ -242,80 +255,53 @@ impl<'a, I: GuiIntegration> Controller<'a, I> {
self.integration.set_welcome_window_visible(true)?;
}
let tokio_handle = tokio::runtime::Handle::current();
let dns_control_method = Default::default();
let mut dns_notifier = new_dns_notifier(tokio_handle.clone(), dns_control_method).await?;
let mut network_notifier =
new_network_notifier(tokio_handle.clone(), dns_control_method).await?;
drop(tokio_handle);
loop {
// TODO: Add `ControllerRequest::NetworkChange` and `DnsChange` and replace
// `tokio::select!` with a `poll_*` function
tokio::select! {
result = network_notifier.notified() => {
result?;
while let Some(tick) = self.tick().await {
match tick {
EventloopTick::NetworkChanged(Ok(())) => {
if self.status.needs_network_changes() {
tracing::debug!("Internet up/down changed, calling `Session::reset`");
self.ipc_client.reset().await?
}
self.try_retry_connection().await?
}
result = dns_notifier.notified() => {
result?;
EventloopTick::DnsChanged(Ok(())) => {
if self.status.needs_network_changes() {
let resolvers = firezone_headless_client::dns_control::system_resolvers_for_gui()?;
tracing::debug!(?resolvers, "New DNS resolvers, calling `Session::set_dns`");
let resolvers =
firezone_headless_client::dns_control::system_resolvers_for_gui()?;
tracing::debug!(
?resolvers,
"New DNS resolvers, calling `Session::set_dns`"
);
self.ipc_client.set_dns(resolvers).await?;
}
self.try_retry_connection().await?
}
event = self.ipc_rx.recv() => {
let event = event.context("IPC task stopped")?;
EventloopTick::NetworkChanged(Err(e)) | EventloopTick::DnsChanged(Err(e)) => {
return Err(Error::Other(e))
}
EventloopTick::IpcEvent(event) => {
if let ControlFlow::Break(()) = self.handle_ipc_event(event).await? {
break;
}
}
req = self.rx.recv() => {
let Some(req) = req else {
tracing::warn!("Controller channel closed, breaking main loop.");
break;
};
#[expect(clippy::wildcard_enum_match_arm)]
match req {
// SAFETY: Crashing is unsafe
Req::Fail(Failure::Crash) => {
tracing::error!("Crashing on purpose");
unsafe { sadness_generator::raise_segfault() }
},
Req::Fail(Failure::Error) => Err(anyhow!("Test error"))?,
Req::Fail(Failure::Panic) => panic!("Test panic"),
Req::SystemTrayMenu(TrayMenuEvent::Quit) => {
tracing::info!("User clicked Quit in the menu");
self.status = Status::Quitting;
self.ipc_client.send_msg(&IpcClientMsg::Disconnect).await?;
self.refresh_system_tray_menu()?;
}
// TODO: Should we really skip cleanup if a request fails?
req => self.handle_request(req).await?,
}
EventloopTick::ControllerRequest(Some(req)) => self.handle_request(req).await?,
EventloopTick::ControllerRequest(None) => {
tracing::warn!("Controller channel closed, breaking main loop");
break;
}
EventloopTick::UpdateNotification(Some(notification)) => {
self.handle_update_notification(notification)?
}
EventloopTick::UpdateNotification(None) => {
return Err(Error::Other(anyhow!("Update checker task stopped")))
}
notification = self.updates_rx.recv() => self.handle_update_notification(notification.context("Update checker task stopped")?)?,
}
// Code down here may not run because the `select` sometimes `continue`s.
}
tracing::debug!("Closing...");
if let Err(error) = dns_notifier.close() {
tracing::error!(error = anyhow_dyn_err(&error), "dns_notifier");
}
if let Err(error) = network_notifier.close() {
tracing::error!(error = anyhow_dyn_err(&error), "network_notifier");
}
if let Err(error) = self.ipc_client.disconnect_from_ipc().await {
tracing::error!(error = anyhow_dyn_err(&error), "ipc_client");
}
@@ -325,6 +311,35 @@ impl<'a, I: GuiIntegration> Controller<'a, I> {
Ok(())
}
async fn tick(&mut self) -> Option<EventloopTick> {
std::future::poll_fn(|cx| {
if let Poll::Ready(Some(res)) = self.dns_notifier.poll_next_unpin(cx) {
return Poll::Ready(Some(EventloopTick::DnsChanged(res)));
}
if let Poll::Ready(Some(res)) = self.network_notifier.poll_next_unpin(cx) {
return Poll::Ready(Some(EventloopTick::NetworkChanged(res)));
}
if let Poll::Ready(maybe_ipc) = self.ipc_rx.poll_next_unpin(cx) {
return Poll::Ready(Some(EventloopTick::IpcEvent(
maybe_ipc.unwrap_or(ipc::Event::Closed),
)));
}
if let Poll::Ready(maybe_req) = self.rx.poll_next_unpin(cx) {
return Poll::Ready(Some(EventloopTick::ControllerRequest(maybe_req)));
}
if let Poll::Ready(notification) = self.updates_rx.poll_next_unpin(cx) {
return Poll::Ready(Some(EventloopTick::UpdateNotification(notification)));
}
Poll::Pending
})
.await
}
async fn start_session(&mut self, token: SecretString) -> Result<(), Error> {
match self.status {
Status::Disconnected | Status::RetryingConnection { .. } => {}
@@ -397,9 +412,13 @@ impl<'a, I: GuiIntegration> Controller<'a, I> {
Req::ExportLogs { path, stem } => logging::export_logs_to(path, stem)
.await
.context("Failed to export logs to zip")?,
Req::Fail(_) => Err(anyhow!(
"Impossible error: `Fail` should be handled before this"
))?,
Req::Fail(Failure::Crash) => {
tracing::error!("Crashing on purpose");
// SAFETY: Crashing is unsafe
unsafe { sadness_generator::raise_segfault() }
},
Req::Fail(Failure::Error) => Err(anyhow!("Test error"))?,
Req::Fail(Failure::Panic) => panic!("Test panic"),
Req::GetAdvancedSettings(tx) => {
tx.send(self.advanced_settings.clone()).ok();
}
@@ -482,9 +501,12 @@ impl<'a, I: GuiIntegration> Controller<'a, I> {
self.integration.open_url(&url)
.context("Couldn't open URL from system tray")?
}
Req::SystemTrayMenu(TrayMenuEvent::Quit) => Err(anyhow!(
"Impossible error: `Quit` should be handled before this"
))?,
Req::SystemTrayMenu(TrayMenuEvent::Quit) => {
tracing::info!("User clicked Quit in the menu");
self.status = Status::Quitting;
self.ipc_client.send_msg(&IpcClientMsg::Disconnect).await?;
self.refresh_system_tray_menu()?;
}
Req::UpdateNotificationClicked(download_url) => {
tracing::info!("UpdateNotificationClicked in run_controller!");
self.integration.open_url(&download_url)
@@ -502,16 +524,18 @@ impl<'a, I: GuiIntegration> Controller<'a, I> {
Err(Error::ConnectToFirezoneFailed(error)) => {
tracing::error!("Failed to connect to Firezone: {error}");
self.sign_out().await?;
Ok(ControlFlow::Continue(()))
}
Err(error) => Err(error)?,
Err(error) => Err(error),
},
ipc::Event::ReadFailed(error) => {
// IPC errors are always fatal
tracing::error!(error = anyhow_dyn_err(&error), "IPC read failure");
Err(Error::IpcRead)?
Err(Error::IpcRead)
}
ipc::Event::Closed => Err(Error::IpcClosed)?,
ipc::Event::Closed => Err(Error::IpcClosed),
}
}
@@ -778,3 +802,31 @@ impl<'a, I: GuiIntegration> Controller<'a, I> {
Ok(())
}
}
async fn new_dns_notifier() -> Result<impl Stream<Item = Result<()>>> {
let worker = firezone_bin_shared::new_dns_notifier(
tokio::runtime::Handle::current(),
DnsControlMethod::default(),
)
.await?;
Ok(stream::try_unfold(worker, |mut worker| async move {
let () = worker.notified().await?;
Ok(Some(((), worker)))
}))
}
async fn new_network_notifier() -> Result<impl Stream<Item = Result<()>>> {
let worker = firezone_bin_shared::new_network_notifier(
tokio::runtime::Handle::current(),
DnsControlMethod::default(),
)
.await?;
Ok(stream::try_unfold(worker, |mut worker| async move {
let () = worker.notified().await?;
Ok(Some(((), worker)))
}))
}

View File

@@ -11,7 +11,7 @@ use anyhow::{bail, Context, Result};
use common::system_tray::Event as TrayMenuEvent;
use firezone_gui_client_common::{
self as common,
controller::{ControllerRequest, CtlrTx, GuiIntegration},
controller::{Controller, ControllerRequest, CtlrTx, GuiIntegration},
deep_link,
errors::{self, Error},
settings::AdvancedSettings,
@@ -236,7 +236,7 @@ pub(crate) fn run(
let app_handle = app.handle().clone();
let _ctlr_task = tokio::spawn(async move {
let result = AssertUnwindSafe(run_controller(
let result = AssertUnwindSafe(Controller::start(
ctlr_tx,
integration,
ctlr_rx,
@@ -467,34 +467,3 @@ fn handle_system_tray_event(app: &tauri::AppHandle, event: TrayMenuEvent) -> Res
.blocking_send(ControllerRequest::SystemTrayMenu(event))?;
Ok(())
}
// TODO: Move this into `impl Controller`
async fn run_controller(
ctlr_tx: CtlrTx,
integration: TauriIntegration,
rx: mpsc::Receiver<ControllerRequest>,
advanced_settings: AdvancedSettings,
log_filter_reloader: LogFilterReloader,
telemetry: &mut telemetry::Telemetry,
updates_rx: mpsc::Receiver<Option<updates::Notification>>,
) -> Result<(), Error> {
tracing::debug!("Entered `run_controller`");
let controller = firezone_gui_client_common::controller::Builder {
advanced_settings,
ctlr_tx,
integration,
log_filter_reloader,
rx,
telemetry,
updates_rx,
}
.build()
.await?;
controller.main_loop().await?;
// Last chance to do any drops / cleanup before the process crashes.
Ok(())
}

View File

@@ -14,7 +14,7 @@ use firezone_bin_shared::{
use firezone_headless_client::{
device_id, signals, CallbackHandler, CliCommon, ConnlibMsg, DnsController,
};
use firezone_logging::{anyhow_dyn_err, telemetry_span};
use firezone_logging::telemetry_span;
use firezone_telemetry::Telemetry;
use futures::StreamExt as _;
use phoenix_channel::get_user_agent;
@@ -316,13 +316,6 @@ fn main() -> Result<()> {
}
};
if let Err(error) = dns_notifier.close() {
tracing::error!(error = anyhow_dyn_err(&error), "DNS notifier")
}
if let Err(error) = network_notifier.close() {
tracing::error!(error = anyhow_dyn_err(&error), "network notifier");
}
telemetry.stop().await; // Stop telemetry before dropping session. `connlib` needs to be active for this, otherwise we won't be able to resolve the DNS name for sentry.
session.disconnect();