refactor(windows): fix cancellation-safety when listening for network events (#4227)

Not sure if this is a fix or a refactor.

Closes #4226
This commit is contained in:
Reactor Scram
2024-03-20 15:11:14 -05:00
committed by GitHub
parent e05cbbe0a0
commit 58f1b357d9
2 changed files with 36 additions and 29 deletions

View File

@@ -44,7 +44,7 @@ tauri = { version = "1.6", features = [ "dialog", "shell-open-api", "system-tray
tauri-runtime = "0.14.2"
tauri-utils = "1.5.3"
thiserror = { version = "1.0", default-features = false }
tokio = { version = "1.36.0", features = ["time"] }
tokio = { version = "1.36.0", features = ["signal", "time"] }
tracing = { workspace = true }
tracing-log = "0.2"
tracing-panic = "0.1.1"

View File

@@ -57,8 +57,7 @@
//! Raymond Chen also explains it on his blog: <https://devblogs.microsoft.com/oldnewthing/20191125-00/?p=103135>
use anyhow::Result;
use std::sync::Arc;
use tokio::{runtime::Runtime, sync::Notify};
use tokio::{runtime::Runtime, sync::mpsc};
use windows::{
core::{Interface, Result as WinResult, GUID},
Win32::{
@@ -91,7 +90,7 @@ pub(crate) fn run_debug() -> Result<()> {
// Returns Err before COM is initialized
assert!(get_apartment_type().is_err());
let com_worker = Worker::new()?;
let mut com_worker = Worker::new()?;
// We have to initialize COM again for the main thread. This doesn't
// seem to be a problem in the main app since Tauri initializes COM for itself.
@@ -108,7 +107,10 @@ pub(crate) fn run_debug() -> Result<()> {
rt.block_on(async move {
loop {
com_worker.notified().await;
tokio::select! {
_r = tokio::signal::ctrl_c() => break,
() = com_worker.notified() => {},
};
// Make sure whatever Tokio thread we're on is associated with COM
// somehow.
assert_eq!(
@@ -118,7 +120,10 @@ pub(crate) fn run_debug() -> Result<()> {
tracing::info!(have_internet = %check_internet()?);
}
})
Ok::<_, anyhow::Error>(())
})?;
Ok(())
}
/// Runs a debug subcommand that listens to the registry for DNS changes
@@ -188,7 +193,7 @@ pub fn check_internet() -> Result<bool> {
/// Worker thread that can be joined explicitly, and joins on Drop
pub(crate) struct Worker {
inner: Option<WorkerInner>,
notify: Arc<Notify>,
rx: mpsc::Receiver<()>,
}
/// Needed so that `Drop` can consume the oneshot Sender and the thread's JoinHandle
@@ -199,18 +204,17 @@ struct WorkerInner {
impl Worker {
pub(crate) fn new() -> Result<Self> {
let notify = Arc::new(Notify::new());
let (tx, rx) = mpsc::channel(1);
let (stopper, rx) = tokio::sync::oneshot::channel();
let (stopper, stopper_rx) = tokio::sync::oneshot::channel();
let thread = {
let notify = Arc::clone(&notify);
std::thread::Builder::new()
.name("Firezone COM worker".into())
.spawn(move || {
{
let com = ComGuard::new()?;
let _network_change_listener = Listener::new(&com, notify)?;
rx.blocking_recv().ok();
let _network_change_listener = Listener::new(&com, tx)?;
stopper_rx.blocking_recv().ok();
}
tracing::debug!("COM worker thread shut down gracefully");
Ok(())
@@ -219,7 +223,7 @@ impl Worker {
Ok(Self {
inner: Some(WorkerInner { thread, stopper }),
notify,
rx,
})
}
@@ -238,8 +242,8 @@ impl Worker {
Ok(())
}
pub(crate) async fn notified(&self) {
self.notify.notified().await;
pub(crate) async fn notified(&mut self) {
self.rx.recv().await;
}
}
@@ -302,19 +306,14 @@ struct Listener<'a> {
advise_cookie_net: Option<u32>,
cxn_point_net: Com::IConnectionPoint,
inner: ListenerInner,
/// Hold a reference to a `ComGuard` to enforce the right init-use-uninit order
_com: &'a ComGuard,
}
/// This must be separate because we need to `Clone` that `Notify` and we can't
/// `Clone` the COM objects in `Listener`
// https://kennykerr.ca/rust-getting-started/how-to-implement-com-interface.html
#[windows_implement::implement(INetworkEvents)]
#[derive(Clone)]
struct ListenerInner {
notify: Arc<Notify>,
struct Callback {
tx: mpsc::Sender<()>,
}
impl<'a> Drop for Listener<'a> {
@@ -330,9 +329,9 @@ impl<'a> Listener<'a> {
///
/// * `com` - Makes sure that CoInitializeEx was called. `com` have been created
/// on the same thread as `new` is called on.
/// * `notify` - A Tokio `Notify` that will be notified when Windows detects
/// * `tx` - A Sender to notify when Windows detects
/// connectivity changes. Some notifications may be spurious.
fn new(com: &'a ComGuard, notify: Arc<Notify>) -> Result<Self, Error> {
fn new(com: &'a ComGuard, tx: mpsc::Sender<()>) -> Result<Self, Error> {
// `windows-rs` automatically releases (de-refs) COM objects on Drop:
// https://github.com/microsoft/windows-rs/issues/2123#issuecomment-1293194755
// https://github.com/microsoft/windows-rs/blob/cefdabd15e4a7a7f71b7a2d8b12d5dc148c99adb/crates/samples/windows/wmi/src/main.rs#L22
@@ -349,11 +348,12 @@ impl<'a> Listener<'a> {
let mut this = Listener {
advise_cookie_net: None,
cxn_point_net,
inner: ListenerInner { notify },
_com: com,
};
let callbacks: INetworkEvents = this.inner.clone().into();
let cb = Callback { tx: tx.clone() };
let callbacks: INetworkEvents = cb.into();
// SAFETY: What happens if Windows sends us a network change event while
// we're dropping Listener?
@@ -367,7 +367,7 @@ impl<'a> Listener<'a> {
// 2. Caller continues setup, checks Internet is connected
// 3. Internet gets disconnected but caller isn't notified
// 4. Worker thread finally gets scheduled, but we never notify that the Internet was lost during setup. Caller is now out of sync with ground truth.
this.inner.notify.notify_one();
tx.try_send(()).ok();
Ok(this)
}
@@ -385,7 +385,7 @@ impl<'a> Listener<'a> {
}
}
impl INetworkEvents_Impl for ListenerInner {
impl INetworkEvents_Impl for Callback {
fn NetworkAdded(&self, _networkid: &GUID) -> WinResult<()> {
Ok(())
}
@@ -399,7 +399,8 @@ impl INetworkEvents_Impl for ListenerInner {
_networkid: &GUID,
_newconnectivity: NLM_CONNECTIVITY,
) -> WinResult<()> {
self.notify.notify_one();
// Use `try_send` because we're only sending a notification to wake up the receiver.
self.tx.try_send(()).ok();
Ok(())
}
@@ -412,6 +413,12 @@ impl INetworkEvents_Impl for ListenerInner {
}
}
impl Drop for Callback {
fn drop(&mut self) {
tracing::debug!("Dropped `network_changes::Callback`");
}
}
/// Checks what COM apartment the current thread is in. For debugging only.
fn get_apartment_type() -> WinResult<(Com::APTTYPE, Com::APTTYPEQUALIFIER)> {
let mut apt_type = Com::APTTYPE_CURRENT;