From 58f1b357d9e8b365805bb0b33382b6ddf5842b35 Mon Sep 17 00:00:00 2001 From: Reactor Scram Date: Wed, 20 Mar 2024 15:11:14 -0500 Subject: [PATCH] refactor(windows): fix cancellation-safety when listening for network events (#4227) Not sure if this is a fix or a refactor. Closes #4226 --- rust/gui-client/src-tauri/Cargo.toml | 2 +- .../src/client/network_changes/windows.rs | 63 ++++++++++--------- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/rust/gui-client/src-tauri/Cargo.toml b/rust/gui-client/src-tauri/Cargo.toml index 8eb22342e..aa9a81b74 100644 --- a/rust/gui-client/src-tauri/Cargo.toml +++ b/rust/gui-client/src-tauri/Cargo.toml @@ -44,7 +44,7 @@ tauri = { version = "1.6", features = [ "dialog", "shell-open-api", "system-tray tauri-runtime = "0.14.2" tauri-utils = "1.5.3" thiserror = { version = "1.0", default-features = false } -tokio = { version = "1.36.0", features = ["time"] } +tokio = { version = "1.36.0", features = ["signal", "time"] } tracing = { workspace = true } tracing-log = "0.2" tracing-panic = "0.1.1" diff --git a/rust/gui-client/src-tauri/src/client/network_changes/windows.rs b/rust/gui-client/src-tauri/src/client/network_changes/windows.rs index 990d6d18b..6486a0716 100644 --- a/rust/gui-client/src-tauri/src/client/network_changes/windows.rs +++ b/rust/gui-client/src-tauri/src/client/network_changes/windows.rs @@ -57,8 +57,7 @@ //! Raymond Chen also explains it on his blog: use anyhow::Result; -use std::sync::Arc; -use tokio::{runtime::Runtime, sync::Notify}; +use tokio::{runtime::Runtime, sync::mpsc}; use windows::{ core::{Interface, Result as WinResult, GUID}, Win32::{ @@ -91,7 +90,7 @@ pub(crate) fn run_debug() -> Result<()> { // Returns Err before COM is initialized assert!(get_apartment_type().is_err()); - let com_worker = Worker::new()?; + let mut com_worker = Worker::new()?; // We have to initialize COM again for the main thread. This doesn't // seem to be a problem in the main app since Tauri initializes COM for itself. @@ -108,7 +107,10 @@ pub(crate) fn run_debug() -> Result<()> { rt.block_on(async move { loop { - com_worker.notified().await; + tokio::select! { + _r = tokio::signal::ctrl_c() => break, + () = com_worker.notified() => {}, + }; // Make sure whatever Tokio thread we're on is associated with COM // somehow. assert_eq!( @@ -118,7 +120,10 @@ pub(crate) fn run_debug() -> Result<()> { tracing::info!(have_internet = %check_internet()?); } - }) + Ok::<_, anyhow::Error>(()) + })?; + + Ok(()) } /// Runs a debug subcommand that listens to the registry for DNS changes @@ -188,7 +193,7 @@ pub fn check_internet() -> Result { /// Worker thread that can be joined explicitly, and joins on Drop pub(crate) struct Worker { inner: Option, - notify: Arc, + rx: mpsc::Receiver<()>, } /// Needed so that `Drop` can consume the oneshot Sender and the thread's JoinHandle @@ -199,18 +204,17 @@ struct WorkerInner { impl Worker { pub(crate) fn new() -> Result { - let notify = Arc::new(Notify::new()); + let (tx, rx) = mpsc::channel(1); - let (stopper, rx) = tokio::sync::oneshot::channel(); + let (stopper, stopper_rx) = tokio::sync::oneshot::channel(); let thread = { - let notify = Arc::clone(¬ify); std::thread::Builder::new() .name("Firezone COM worker".into()) .spawn(move || { { let com = ComGuard::new()?; - let _network_change_listener = Listener::new(&com, notify)?; - rx.blocking_recv().ok(); + let _network_change_listener = Listener::new(&com, tx)?; + stopper_rx.blocking_recv().ok(); } tracing::debug!("COM worker thread shut down gracefully"); Ok(()) @@ -219,7 +223,7 @@ impl Worker { Ok(Self { inner: Some(WorkerInner { thread, stopper }), - notify, + rx, }) } @@ -238,8 +242,8 @@ impl Worker { Ok(()) } - pub(crate) async fn notified(&self) { - self.notify.notified().await; + pub(crate) async fn notified(&mut self) { + self.rx.recv().await; } } @@ -302,19 +306,14 @@ struct Listener<'a> { advise_cookie_net: Option, cxn_point_net: Com::IConnectionPoint, - inner: ListenerInner, - /// Hold a reference to a `ComGuard` to enforce the right init-use-uninit order _com: &'a ComGuard, } -/// This must be separate because we need to `Clone` that `Notify` and we can't -/// `Clone` the COM objects in `Listener` // https://kennykerr.ca/rust-getting-started/how-to-implement-com-interface.html #[windows_implement::implement(INetworkEvents)] -#[derive(Clone)] -struct ListenerInner { - notify: Arc, +struct Callback { + tx: mpsc::Sender<()>, } impl<'a> Drop for Listener<'a> { @@ -330,9 +329,9 @@ impl<'a> Listener<'a> { /// /// * `com` - Makes sure that CoInitializeEx was called. `com` have been created /// on the same thread as `new` is called on. - /// * `notify` - A Tokio `Notify` that will be notified when Windows detects + /// * `tx` - A Sender to notify when Windows detects /// connectivity changes. Some notifications may be spurious. - fn new(com: &'a ComGuard, notify: Arc) -> Result { + fn new(com: &'a ComGuard, tx: mpsc::Sender<()>) -> Result { // `windows-rs` automatically releases (de-refs) COM objects on Drop: // https://github.com/microsoft/windows-rs/issues/2123#issuecomment-1293194755 // https://github.com/microsoft/windows-rs/blob/cefdabd15e4a7a7f71b7a2d8b12d5dc148c99adb/crates/samples/windows/wmi/src/main.rs#L22 @@ -349,11 +348,12 @@ impl<'a> Listener<'a> { let mut this = Listener { advise_cookie_net: None, cxn_point_net, - inner: ListenerInner { notify }, _com: com, }; - let callbacks: INetworkEvents = this.inner.clone().into(); + let cb = Callback { tx: tx.clone() }; + + let callbacks: INetworkEvents = cb.into(); // SAFETY: What happens if Windows sends us a network change event while // we're dropping Listener? @@ -367,7 +367,7 @@ impl<'a> Listener<'a> { // 2. Caller continues setup, checks Internet is connected // 3. Internet gets disconnected but caller isn't notified // 4. Worker thread finally gets scheduled, but we never notify that the Internet was lost during setup. Caller is now out of sync with ground truth. - this.inner.notify.notify_one(); + tx.try_send(()).ok(); Ok(this) } @@ -385,7 +385,7 @@ impl<'a> Listener<'a> { } } -impl INetworkEvents_Impl for ListenerInner { +impl INetworkEvents_Impl for Callback { fn NetworkAdded(&self, _networkid: &GUID) -> WinResult<()> { Ok(()) } @@ -399,7 +399,8 @@ impl INetworkEvents_Impl for ListenerInner { _networkid: &GUID, _newconnectivity: NLM_CONNECTIVITY, ) -> WinResult<()> { - self.notify.notify_one(); + // Use `try_send` because we're only sending a notification to wake up the receiver. + self.tx.try_send(()).ok(); Ok(()) } @@ -412,6 +413,12 @@ impl INetworkEvents_Impl for ListenerInner { } } +impl Drop for Callback { + fn drop(&mut self) { + tracing::debug!("Dropped `network_changes::Callback`"); + } +} + /// Checks what COM apartment the current thread is in. For debugging only. fn get_apartment_type() -> WinResult<(Com::APTTYPE, Com::APTTYPEQUALIFIER)> { let mut apt_type = Com::APTTYPE_CURRENT;