mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(windows): fix cancellation-safety when listening for network events (#4227)
Not sure if this is a fix or a refactor. Closes #4226
This commit is contained in:
@@ -44,7 +44,7 @@ tauri = { version = "1.6", features = [ "dialog", "shell-open-api", "system-tray
|
||||
tauri-runtime = "0.14.2"
|
||||
tauri-utils = "1.5.3"
|
||||
thiserror = { version = "1.0", default-features = false }
|
||||
tokio = { version = "1.36.0", features = ["time"] }
|
||||
tokio = { version = "1.36.0", features = ["signal", "time"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-log = "0.2"
|
||||
tracing-panic = "0.1.1"
|
||||
|
||||
@@ -57,8 +57,7 @@
|
||||
//! Raymond Chen also explains it on his blog: <https://devblogs.microsoft.com/oldnewthing/20191125-00/?p=103135>
|
||||
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
use tokio::{runtime::Runtime, sync::Notify};
|
||||
use tokio::{runtime::Runtime, sync::mpsc};
|
||||
use windows::{
|
||||
core::{Interface, Result as WinResult, GUID},
|
||||
Win32::{
|
||||
@@ -91,7 +90,7 @@ pub(crate) fn run_debug() -> Result<()> {
|
||||
// Returns Err before COM is initialized
|
||||
assert!(get_apartment_type().is_err());
|
||||
|
||||
let com_worker = Worker::new()?;
|
||||
let mut com_worker = Worker::new()?;
|
||||
|
||||
// We have to initialize COM again for the main thread. This doesn't
|
||||
// seem to be a problem in the main app since Tauri initializes COM for itself.
|
||||
@@ -108,7 +107,10 @@ pub(crate) fn run_debug() -> Result<()> {
|
||||
|
||||
rt.block_on(async move {
|
||||
loop {
|
||||
com_worker.notified().await;
|
||||
tokio::select! {
|
||||
_r = tokio::signal::ctrl_c() => break,
|
||||
() = com_worker.notified() => {},
|
||||
};
|
||||
// Make sure whatever Tokio thread we're on is associated with COM
|
||||
// somehow.
|
||||
assert_eq!(
|
||||
@@ -118,7 +120,10 @@ pub(crate) fn run_debug() -> Result<()> {
|
||||
|
||||
tracing::info!(have_internet = %check_internet()?);
|
||||
}
|
||||
})
|
||||
Ok::<_, anyhow::Error>(())
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Runs a debug subcommand that listens to the registry for DNS changes
|
||||
@@ -188,7 +193,7 @@ pub fn check_internet() -> Result<bool> {
|
||||
/// Worker thread that can be joined explicitly, and joins on Drop
|
||||
pub(crate) struct Worker {
|
||||
inner: Option<WorkerInner>,
|
||||
notify: Arc<Notify>,
|
||||
rx: mpsc::Receiver<()>,
|
||||
}
|
||||
|
||||
/// Needed so that `Drop` can consume the oneshot Sender and the thread's JoinHandle
|
||||
@@ -199,18 +204,17 @@ struct WorkerInner {
|
||||
|
||||
impl Worker {
|
||||
pub(crate) fn new() -> Result<Self> {
|
||||
let notify = Arc::new(Notify::new());
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
|
||||
let (stopper, rx) = tokio::sync::oneshot::channel();
|
||||
let (stopper, stopper_rx) = tokio::sync::oneshot::channel();
|
||||
let thread = {
|
||||
let notify = Arc::clone(¬ify);
|
||||
std::thread::Builder::new()
|
||||
.name("Firezone COM worker".into())
|
||||
.spawn(move || {
|
||||
{
|
||||
let com = ComGuard::new()?;
|
||||
let _network_change_listener = Listener::new(&com, notify)?;
|
||||
rx.blocking_recv().ok();
|
||||
let _network_change_listener = Listener::new(&com, tx)?;
|
||||
stopper_rx.blocking_recv().ok();
|
||||
}
|
||||
tracing::debug!("COM worker thread shut down gracefully");
|
||||
Ok(())
|
||||
@@ -219,7 +223,7 @@ impl Worker {
|
||||
|
||||
Ok(Self {
|
||||
inner: Some(WorkerInner { thread, stopper }),
|
||||
notify,
|
||||
rx,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -238,8 +242,8 @@ impl Worker {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn notified(&self) {
|
||||
self.notify.notified().await;
|
||||
pub(crate) async fn notified(&mut self) {
|
||||
self.rx.recv().await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,19 +306,14 @@ struct Listener<'a> {
|
||||
advise_cookie_net: Option<u32>,
|
||||
cxn_point_net: Com::IConnectionPoint,
|
||||
|
||||
inner: ListenerInner,
|
||||
|
||||
/// Hold a reference to a `ComGuard` to enforce the right init-use-uninit order
|
||||
_com: &'a ComGuard,
|
||||
}
|
||||
|
||||
/// This must be separate because we need to `Clone` that `Notify` and we can't
|
||||
/// `Clone` the COM objects in `Listener`
|
||||
// https://kennykerr.ca/rust-getting-started/how-to-implement-com-interface.html
|
||||
#[windows_implement::implement(INetworkEvents)]
|
||||
#[derive(Clone)]
|
||||
struct ListenerInner {
|
||||
notify: Arc<Notify>,
|
||||
struct Callback {
|
||||
tx: mpsc::Sender<()>,
|
||||
}
|
||||
|
||||
impl<'a> Drop for Listener<'a> {
|
||||
@@ -330,9 +329,9 @@ impl<'a> Listener<'a> {
|
||||
///
|
||||
/// * `com` - Makes sure that CoInitializeEx was called. `com` have been created
|
||||
/// on the same thread as `new` is called on.
|
||||
/// * `notify` - A Tokio `Notify` that will be notified when Windows detects
|
||||
/// * `tx` - A Sender to notify when Windows detects
|
||||
/// connectivity changes. Some notifications may be spurious.
|
||||
fn new(com: &'a ComGuard, notify: Arc<Notify>) -> Result<Self, Error> {
|
||||
fn new(com: &'a ComGuard, tx: mpsc::Sender<()>) -> Result<Self, Error> {
|
||||
// `windows-rs` automatically releases (de-refs) COM objects on Drop:
|
||||
// https://github.com/microsoft/windows-rs/issues/2123#issuecomment-1293194755
|
||||
// https://github.com/microsoft/windows-rs/blob/cefdabd15e4a7a7f71b7a2d8b12d5dc148c99adb/crates/samples/windows/wmi/src/main.rs#L22
|
||||
@@ -349,11 +348,12 @@ impl<'a> Listener<'a> {
|
||||
let mut this = Listener {
|
||||
advise_cookie_net: None,
|
||||
cxn_point_net,
|
||||
inner: ListenerInner { notify },
|
||||
_com: com,
|
||||
};
|
||||
|
||||
let callbacks: INetworkEvents = this.inner.clone().into();
|
||||
let cb = Callback { tx: tx.clone() };
|
||||
|
||||
let callbacks: INetworkEvents = cb.into();
|
||||
|
||||
// SAFETY: What happens if Windows sends us a network change event while
|
||||
// we're dropping Listener?
|
||||
@@ -367,7 +367,7 @@ impl<'a> Listener<'a> {
|
||||
// 2. Caller continues setup, checks Internet is connected
|
||||
// 3. Internet gets disconnected but caller isn't notified
|
||||
// 4. Worker thread finally gets scheduled, but we never notify that the Internet was lost during setup. Caller is now out of sync with ground truth.
|
||||
this.inner.notify.notify_one();
|
||||
tx.try_send(()).ok();
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
@@ -385,7 +385,7 @@ impl<'a> Listener<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl INetworkEvents_Impl for ListenerInner {
|
||||
impl INetworkEvents_Impl for Callback {
|
||||
fn NetworkAdded(&self, _networkid: &GUID) -> WinResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
@@ -399,7 +399,8 @@ impl INetworkEvents_Impl for ListenerInner {
|
||||
_networkid: &GUID,
|
||||
_newconnectivity: NLM_CONNECTIVITY,
|
||||
) -> WinResult<()> {
|
||||
self.notify.notify_one();
|
||||
// Use `try_send` because we're only sending a notification to wake up the receiver.
|
||||
self.tx.try_send(()).ok();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -412,6 +413,12 @@ impl INetworkEvents_Impl for ListenerInner {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Callback {
|
||||
fn drop(&mut self) {
|
||||
tracing::debug!("Dropped `network_changes::Callback`");
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks what COM apartment the current thread is in. For debugging only.
|
||||
fn get_apartment_type() -> WinResult<(Com::APTTYPE, Com::APTTYPEQUALIFIER)> {
|
||||
let mut apt_type = Com::APTTYPE_CURRENT;
|
||||
|
||||
Reference in New Issue
Block a user