feat(headless-client/windows): add DNS change / network change listening to the Headless Client (#6022)

Note that for GUI Clients, listening is still done by the GUI process,
not the IPC service.

Yak shave towards #5846. This allows for faster dev cycles since I won't
have to compile all the GUI stuff.

Some changes in here were extracted from other draft PRs.

Changes:
- Remove `thiserror` that was never matched on
- Don't return the DNS resolvers from the notifier directly, just send a
notification and allow the caller to check the resolvers itself if
needed
- Rename `DnsListener` to `DnsNotifier`
- Rename `Worker` to `NetworkNotifier`
- remove `unwrap_or_default` when getting resolvers. I don't know why
it's there, if there's a good reason then it should be handled inside
the function, not in the caller

```[tasklist]
### Tasks
- [x] Rename `*Listener` to `*Notifier`
- [x] (not needed) ~~Support `/etc/resolv.conf` DNS control method too?~~
```
This commit is contained in:
Reactor Scram
2024-07-25 10:45:22 -05:00
committed by GitHub
parent 801a816f36
commit cc1478adc2
12 changed files with 145 additions and 166 deletions

5
rust/Cargo.lock generated
View File

@@ -1801,6 +1801,9 @@ dependencies = [
"url",
"uuid",
"windows 0.57.0",
"windows-core 0.57.0",
"windows-implement 0.57.0",
"winreg 0.52.0",
"wintun",
]
@@ -1886,8 +1889,6 @@ dependencies = [
"url",
"uuid",
"windows 0.57.0",
"windows-core 0.57.0",
"windows-implement 0.57.0",
"winreg 0.52.0",
"wintun",
"zip",

View File

@@ -32,15 +32,28 @@ libc = "0.2"
known-folders = "1.1.0"
ring = "0.17"
uuid = { version = "1.10.0", features = ["v4"] }
windows-core = "0.57.0"
windows-implement = "0.57.0"
wintun = "0.4.0"
winreg = "0.52.0"
[target.'cfg(windows)'.dependencies.windows]
version = "0.57.0"
features = [
# For implementing COM interfaces
"implement",
"Win32_Foundation",
# For listening for network change events
"Win32_Networking_NetworkListManager",
"Win32_NetworkManagement_IpHelper",
"Win32_NetworkManagement_Ndis",
"Win32_Networking_WinSock",
"Win32_Security",
# COM is needed to listen for network change events
"Win32_System_Com",
# Needed to listen for system DNS changes
"Win32_System_Registry",
"Win32_System_Threading",
]
[lints]

View File

@@ -1,3 +1,4 @@
mod network_changes;
mod tun_device_manager;
#[cfg(target_os = "windows")]
@@ -28,6 +29,9 @@ pub const BUNDLE_ID: &str = "dev.firezone.client";
/// Mark for Firezone sockets to prevent routing loops on Linux.
pub const FIREZONE_MARK: u32 = 0xfd002021;
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub use network_changes::{DnsNotifier, NetworkNotifier};
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub use tun_device_manager::TunDeviceManager;

View File

@@ -13,4 +13,4 @@ mod imp;
#[allow(clippy::unnecessary_wraps)]
mod imp;
pub(crate) use imp::{check_internet, DnsListener, Worker};
pub use imp::{DnsNotifier, NetworkNotifier};

View File

@@ -0,0 +1,46 @@
//! Not implemented for Linux yet
use anyhow::Result;
use tokio::time::Interval;
pub struct NetworkNotifier {}
impl NetworkNotifier {
pub fn new() -> Result<Self> {
Ok(Self {})
}
pub fn close(&mut self) -> Result<()> {
Ok(())
}
/// Not implemented on Linux
///
/// On Windows this returns when we gain or lose Internet.
pub async fn notified(&mut self) {
futures::future::pending().await
}
}
pub struct DnsNotifier {
interval: Interval,
}
impl DnsNotifier {
pub fn new() -> Result<Self> {
Ok(Self {
interval: create_interval(),
})
}
pub async fn notified(&mut self) -> Result<()> {
self.interval.tick().await;
Ok(())
}
}
fn create_interval() -> Interval {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
interval
}

View File

@@ -64,7 +64,7 @@
//!
//! Raymond Chen also explains it on his blog: <https://devblogs.microsoft.com/oldnewthing/20191125-00/?p=103135>
use anyhow::Result;
use anyhow::{anyhow, Context as _, Result};
use tokio::sync::mpsc;
use windows::{
core::{Interface, Result as WinResult, GUID},
@@ -77,55 +77,22 @@ use windows::{
},
};
pub(crate) use async_dns::CombinedListener as DnsListener;
pub use async_dns::DnsNotifier;
#[derive(thiserror::Error, Debug)]
pub(crate) enum Error {
#[error("Couldn't initialize COM: {0}")]
ComInitialize(windows::core::Error),
#[error("Couldn't stop worker thread")]
CouldntStopWorkerThread,
#[error("Couldn't creat NetworkListManager")]
CreateNetworkListManager(windows::core::Error),
#[error("Couldn't start listening to network events: {0}")]
Listening(windows::core::Error),
#[error("Couldn't stop listening to network events: {0}")]
Unadvise(windows::core::Error),
}
/// Returns true if Windows thinks we have Internet access per [IsConnectedToInternet](https://learn.microsoft.com/en-us/windows/win32/api/netlistmgr/nf-netlistmgr-inetworklistmanager-get_isconnectedtointernet)
///
/// Call this when `Listener` notifies you.
pub fn check_internet() -> Result<bool> {
// Retrieving the INetworkListManager takes less than half a millisecond, and this
// makes the lifetimes and Send+Sync much simpler for callers, so just retrieve it
// every single time.
// SAFETY: No lifetime problems. TODO: Could threading be a problem?
// I think in practice we'll never call this from two threads, but what if we did?
// Maybe make it a method on a `!Send + !Sync` guard struct?
let network_list_manager: INetworkListManager =
unsafe { Com::CoCreateInstance(&NetworkListManager, None, Com::CLSCTX_ALL) }?;
// SAFETY: `network_list_manager` isn't shared between threads, and the lifetime
// should be good.
let have_internet = unsafe { network_list_manager.IsConnectedToInternet() }?.as_bool();
Ok(have_internet)
}
/// Worker thread that can be joined explicitly, and joins on Drop
pub(crate) struct Worker {
/// Notifies when we change Wi-Fi networks, change between Wi-Fi and Ethernet, or gain / lose Internet
pub struct NetworkNotifier {
inner: Option<WorkerInner>,
rx: mpsc::Receiver<()>,
}
/// Needed so that `Drop` can consume the oneshot Sender and the thread's JoinHandle
struct WorkerInner {
thread: std::thread::JoinHandle<Result<(), Error>>,
thread: std::thread::JoinHandle<Result<()>>,
stopper: tokio::sync::oneshot::Sender<()>,
}
impl Worker {
pub(crate) fn new() -> Result<Self> {
impl NetworkNotifier {
pub fn new() -> Result<Self> {
let (tx, rx) = mpsc::channel(1);
let (stopper, stopper_rx) = tokio::sync::oneshot::channel();
@@ -150,12 +117,12 @@ impl Worker {
}
/// Same as `drop`, but you can catch errors
pub(crate) fn close(&mut self) -> Result<()> {
pub fn close(&mut self) -> Result<()> {
if let Some(inner) = self.inner.take() {
inner
.stopper
.send(())
.map_err(|_| Error::CouldntStopWorkerThread)?;
.map_err(|_| anyhow!("Couldn't stop `NetworkNotifier` worker thread"))?;
match inner.thread.join() {
Err(e) => std::panic::resume_unwind(e),
Ok(x) => x?,
@@ -164,12 +131,12 @@ impl Worker {
Ok(())
}
pub(crate) async fn notified(&mut self) {
pub async fn notified(&mut self) {
self.rx.recv().await;
}
}
impl Drop for Worker {
impl Drop for NetworkNotifier {
fn drop(&mut self) {
self.close()
.expect("should be able to close Worker cleanly");
@@ -193,12 +160,12 @@ type PhantomUnsendUnsync = std::marker::PhantomData<*const ()>;
impl ComGuard {
/// Initialize a "Multi-threaded apartment" so that Windows COM stuff
/// can be called, and COM callbacks can work.
pub fn new() -> Result<Self, Error> {
pub fn new() -> Result<Self> {
// SAFETY: Threading shouldn't be a problem since this is meant to initialize
// COM per-thread anyway.
unsafe { Com::CoInitializeEx(None, Com::COINIT_MULTITHREADED) }
.ok()
.map_err(Error::ComInitialize)?;
.context("Failed in `CoInitializeEx`")?;
Ok(Self {
dropped: false,
_unsend_unsync: Default::default(),
@@ -219,7 +186,7 @@ impl Drop for ComGuard {
}
}
/// Listens to network connectivity change eents
/// Listens to network connectivity change events
struct Listener<'a> {
/// The cookies we get back from `Advise`. Can be None if the owner called `close`
///
@@ -253,19 +220,20 @@ impl<'a> Listener<'a> {
/// on the same thread as `new` is called on.
/// * `tx` - A Sender to notify when Windows detects
/// connectivity changes. Some notifications may be spurious.
fn new(com: &'a ComGuard, tx: mpsc::Sender<()>) -> Result<Self, Error> {
fn new(com: &'a ComGuard, tx: mpsc::Sender<()>) -> Result<Self> {
// `windows-rs` automatically releases (de-refs) COM objects on Drop:
// https://github.com/microsoft/windows-rs/issues/2123#issuecomment-1293194755
// https://github.com/microsoft/windows-rs/blob/cefdabd15e4a7a7f71b7a2d8b12d5dc148c99adb/crates/samples/windows/wmi/src/main.rs#L22
// SAFETY: TODO
let network_list_manager: INetworkListManager =
unsafe { Com::CoCreateInstance(&NetworkListManager, None, Com::CLSCTX_ALL) }
.map_err(Error::CreateNetworkListManager)?;
let cpc: Com::IConnectionPointContainer =
network_list_manager.cast().map_err(Error::Listening)?;
.context("Failed in `CoCreateInstance`")?;
let cpc: Com::IConnectionPointContainer = network_list_manager
.cast()
.context("Failed to cast network list manager")?;
// SAFETY: TODO
let cxn_point_net =
unsafe { cpc.FindConnectionPoint(&INetworkEvents::IID) }.map_err(Error::Listening)?;
let cxn_point_net = unsafe { cpc.FindConnectionPoint(&INetworkEvents::IID) }
.context("Failed in `FindConnectionPoint`")?;
let mut this = Listener {
advise_cookie_net: None,
@@ -280,8 +248,10 @@ impl<'a> Listener<'a> {
// SAFETY: What happens if Windows sends us a network change event while
// we're dropping Listener?
// Is it safe to Advise on `this` and then immediately move it?
this.advise_cookie_net =
Some(unsafe { this.cxn_point_net.Advise(&callbacks) }.map_err(Error::Listening)?);
this.advise_cookie_net = Some(
unsafe { this.cxn_point_net.Advise(&callbacks) }
.context("Failed to listen for network event callbacks")?,
);
// After we call `Advise`, notify. This should avoid a problem if this happens:
//
@@ -297,10 +267,11 @@ impl<'a> Listener<'a> {
/// Like `drop`, but you can catch errors
///
/// Unregisters the network change callbacks
pub fn close(&mut self) -> Result<(), Error> {
pub fn close(&mut self) -> Result<()> {
if let Some(cookie) = self.advise_cookie_net.take() {
// SAFETY: I don't see any memory safety issues.
unsafe { self.cxn_point_net.Unadvise(cookie) }.map_err(Error::Unadvise)?;
unsafe { self.cxn_point_net.Unadvise(cookie) }
.context("Failed to unadvise connection point")?;
tracing::debug!("Unadvised");
}
Ok(())
@@ -343,7 +314,8 @@ impl Drop for Callback {
mod async_dns {
use anyhow::{Context, Result};
use std::{ffi::c_void, net::IpAddr, ops::Deref, path::Path};
use futures::FutureExt as _;
use std::{ffi::c_void, ops::Deref, path::Path, pin::pin};
use tokio::sync::mpsc;
use windows::Win32::{
Foundation::{CloseHandle, BOOLEAN, HANDLE, INVALID_HANDLE_VALUE},
@@ -379,13 +351,13 @@ mod async_dns {
))
}
pub(crate) struct CombinedListener {
pub struct DnsNotifier {
listener_4: Listener,
listener_6: Listener,
}
impl CombinedListener {
pub(crate) fn new() -> Result<Self> {
impl DnsNotifier {
pub fn new() -> Result<Self> {
let (key_ipv4, key_ipv6) = open_network_registry_keys()?;
let listener_4 = Listener::new(key_ipv4)?;
let listener_6 = Listener::new(key_ipv6)?;
@@ -396,15 +368,14 @@ mod async_dns {
})
}
pub(crate) async fn notified(&mut self) -> Result<Vec<IpAddr>> {
tokio::select! {
r = self.listener_4.notified() => r?,
r = self.listener_6.notified() => r?,
pub async fn notified(&mut self) -> Result<()> {
let mut fut_4 = pin!(self.listener_4.notified().fuse());
let mut fut_6 = pin!(self.listener_6.notified().fuse());
futures::select! {
r = fut_4 => r?,
r = fut_6 => r?,
}
Ok(
firezone_headless_client::dns_control::system_resolvers_for_gui()
.unwrap_or_default(),
)
Ok(())
}
}

View File

@@ -66,23 +66,13 @@ nix = { version = "0.28.0", features = ["user"] }
[target.'cfg(target_os = "windows")'.dependencies]
tauri-winrt-notification = "0.5.0"
windows-core = "0.57.0"
windows-implement = "0.57.0"
winreg = "0.52.0"
wintun = "0.4.0"
[target.'cfg(target_os = "windows")'.dependencies.windows]
version = "0.57.0"
features = [
# For implementing COM interfaces
"implement",
"Win32_Foundation",
# For listening for network change events
"Win32_Networking_NetworkListManager",
# COM is needed to listen for network change events
"Win32_System_Com",
# Needed to listen for system DNS changes
"Win32_System_Registry",
"Win32_System_Threading",
]

View File

@@ -13,7 +13,6 @@ mod elevation;
mod gui;
mod ipc;
mod logging;
mod network_changes;
mod settings;
mod updates;
mod uptime;

View File

@@ -4,12 +4,13 @@
//! The real macOS Client is in `swift/apple`
use crate::client::{
self, about, deep_link, ipc, logging, network_changes,
self, about, deep_link, ipc, logging,
settings::{self, AdvancedSettings},
Failure,
};
use anyhow::{anyhow, bail, Context, Result};
use connlib_client_shared::callbacks::ResourceDescription;
use firezone_bin_shared::{DnsNotifier, NetworkNotifier};
use firezone_headless_client::IpcServerMsg;
use secrecy::{ExposeSecret, SecretString};
use std::{
@@ -796,32 +797,24 @@ async fn run_controller(
win.show().context("Couldn't show Welcome window")?;
}
let mut have_internet =
network_changes::check_internet().context("Failed initial check for internet")?;
tracing::info!(?have_internet);
let mut com_worker = NetworkNotifier::new().context("Failed to listen for network changes")?;
let mut com_worker =
network_changes::Worker::new().context("Failed to listen for network changes")?;
let mut dns_listener = network_changes::DnsListener::new()?;
let mut dns_listener = DnsNotifier::new()?;
loop {
// TODO: Add `ControllerRequest::NetworkChange` and `DnsChange` and replace
// `tokio::select!` with a `poll_*` function
tokio::select! {
() = com_worker.notified() => {
let new_have_internet = network_changes::check_internet().context("Failed to check for internet")?;
if new_have_internet != have_internet {
have_internet = new_have_internet;
if controller.status.connlib_is_up() {
tracing::debug!("Internet up/down changed, calling `Session::reconnect`");
controller.ipc_client.reconnect().await?;
}
if controller.status.connlib_is_up() {
tracing::debug!("Internet up/down changed, calling `Session::reconnect`");
controller.ipc_client.reconnect().await?;
}
},
resolvers = dns_listener.notified() => {
let resolvers = resolvers?;
result = dns_listener.notified() => {
result?;
if controller.status.connlib_is_up() {
let resolvers = firezone_headless_client::dns_control::system_resolvers_for_gui()?;
tracing::debug!(?resolvers, "New DNS resolvers, calling `Session::set_dns`");
controller.ipc_client.set_dns(resolvers).await?;
}

View File

@@ -1,62 +0,0 @@
//! Not implemented for Linux yet
use anyhow::Result;
use firezone_headless_client::dns_control::system_resolvers_for_gui;
use std::net::IpAddr;
use tokio::time::Interval;
/// TODO: Implement for Linux
pub(crate) fn check_internet() -> Result<bool> {
Ok(true)
}
pub(crate) struct Worker {}
impl Worker {
pub(crate) fn new() -> Result<Self> {
Ok(Self {})
}
pub(crate) fn close(&mut self) -> Result<()> {
Ok(())
}
/// Not implemented on Linux
///
/// On Windows this returns when we gain or lose Internet.
pub(crate) async fn notified(&mut self) {
futures::future::pending().await
}
}
pub(crate) struct DnsListener {
interval: Interval,
last_seen: Vec<IpAddr>,
}
impl DnsListener {
pub(crate) fn new() -> Result<Self> {
Ok(Self {
interval: create_interval(),
last_seen: system_resolvers_for_gui().unwrap_or_default(),
})
}
pub(crate) async fn notified(&mut self) -> Result<Vec<IpAddr>> {
loop {
self.interval.tick().await;
tracing::trace!("Checking for DNS changes");
let new = system_resolvers_for_gui().unwrap_or_default();
if new != self.last_seen {
self.last_seen.clone_from(&new);
return Ok(new);
}
}
}
}
fn create_interval() -> Interval {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
interval
}

View File

@@ -9,7 +9,9 @@ use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session};
use connlib_shared::get_user_agent;
use firezone_bin_shared::{setup_global_subscriber, TunDeviceManager};
use firezone_bin_shared::{
setup_global_subscriber, DnsNotifier, NetworkNotifier, TunDeviceManager,
};
use futures::{FutureExt as _, StreamExt as _};
use phoenix_channel::PhoenixChannel;
use secrecy::{Secret, SecretString};
@@ -196,22 +198,38 @@ pub fn run_only_headless_client() -> Result<()> {
let mut tun_device = TunDeviceManager::new()?;
let mut cb_rx = ReceiverStream::new(cb_rx).fuse();
let mut dns_notifier = DnsNotifier::new()?;
let mut network_notifier = NetworkNotifier::new()?;
let tun = tun_device.make_tun()?;
session.set_tun(Box::new(tun));
// TODO: DNS should be added dynamically
session.set_dns(dns_control::system_resolvers().unwrap_or_default());
loop {
let mut dns_changed = pin!(dns_notifier.notified().fuse());
let mut network_changed = pin!(network_notifier.notified().fuse());
let cb = futures::select! {
() = terminate => {
tracing::info!("Caught SIGINT / SIGTERM / Ctrl+C");
return Ok(());
break;
},
() = hangup => {
tracing::info!("Caught SIGHUP");
session.reconnect();
continue;
},
result = dns_changed => {
result?;
tracing::info!("DNS change, notifying Session");
session.set_dns(dns_control::system_resolvers()?);
continue;
},
() = network_changed => {
tracing::info!("Network change, reconnecting Session");
session.reconnect();
continue;
},
cb = cb_rx.next() => cb.context("cb_rx unexpectedly ran empty")?,
};
@@ -239,11 +257,17 @@ pub fn run_only_headless_client() -> Result<()> {
}
if cli.exit {
tracing::info!("Exiting due to `--exit` CLI flag");
break Ok(());
break;
}
}
}
}
if let Err(error) = network_notifier.close() {
tracing::error!(?error, "network listener");
}
Ok(())
});
session.disconnect();