diff --git a/rust/headless-client/src/ipc_service/windows.rs b/rust/headless-client/src/ipc_service/windows.rs index 87b676cdd..d9c6ca98e 100644 --- a/rust/headless-client/src/ipc_service/windows.rs +++ b/rust/headless-client/src/ipc_service/windows.rs @@ -2,14 +2,12 @@ use crate::CliCommon; use anyhow::{bail, Context as _, Result}; use firezone_bin_shared::platform::DnsControlMethod; use firezone_telemetry::Telemetry; -use futures::future::{self, Either}; +use futures::channel::mpsc; use std::{ ffi::{c_void, OsString}, mem::size_of, - pin::pin, time::Duration, }; -use tokio::sync::mpsc; use windows::{ core::PWSTR, Win32::{ @@ -218,7 +216,7 @@ fn fallible_service_run( let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - let (shutdown_tx, shutdown_rx) = mpsc::channel(1); + let (mut shutdown_tx, shutdown_rx) = mpsc::channel(1); let event_handler = move |control_event| -> ServiceControlHandlerResult { tracing::debug!(?control_event); @@ -230,7 +228,7 @@ fn fallible_service_run( ServiceControlHandlerResult::NoError } ServiceControl::Shutdown | ServiceControl::Stop => { - if shutdown_tx.blocking_send(()).is_err() { + if shutdown_tx.try_send(()).is_err() { tracing::error!("Should be able to send shutdown signal"); } ServiceControlHandlerResult::NoError @@ -325,28 +323,21 @@ fn fallible_service_run( async fn service_run_async( log_filter_reloader: &crate::LogFilterReloader, telemetry: &mut Telemetry, - mut shutdown_rx: mpsc::Receiver<()>, + shutdown_rx: mpsc::Receiver<()>, ) -> Result<()> { // Useless - Windows will never send us Ctrl+C when running as a service // This just keeps the signatures simpler - let mut signals = crate::signals::Terminate::new()?; - let listen_fut = pin!(super::ipc_listen( + let mut signals = crate::signals::Terminate::from_channel(shutdown_rx); + super::ipc_listen( DnsControlMethod::Nrpt, log_filter_reloader, &mut signals, - telemetry - )); - match future::select(listen_fut, pin!(shutdown_rx.recv())).await { - Either::Left((Err(error), _)) => Err(error).context("`ipc_listen` threw an error"), - Either::Left((Ok(()), _)) => { - bail!("Impossible - Shouldn't catch Ctrl+C when running as a Windows service") - } - Either::Right((None, _)) => bail!("Shutdown channel failed"), - Either::Right((Some(()), _)) => { - tracing::info!("Caught shutdown signal, stopping IPC listener"); - Ok(()) - } - } + telemetry, + ) + .await + .context("`ipc_listen` threw an error")?; + + Ok(()) } #[cfg(test)] diff --git a/rust/headless-client/src/signals/windows.rs b/rust/headless-client/src/signals/windows.rs index 6256ea439..5c1d01560 100644 --- a/rust/headless-client/src/signals/windows.rs +++ b/rust/headless-client/src/signals/windows.rs @@ -1,13 +1,14 @@ use anyhow::Result; use futures::{ + channel::mpsc, future::poll_fn, + stream::BoxStream, task::{Context, Poll}, + StreamExt as _, }; -// This looks like a pointless wrapper around `CtrlC`, because it must match -// the Linux signatures pub struct Terminate { - sigint: tokio::signal::windows::CtrlC, + inner: BoxStream<'static, ()>, } // SIGHUP is used on Linux but not on Windows @@ -16,11 +17,22 @@ pub struct Hangup {} impl Terminate { pub fn new() -> Result { let sigint = tokio::signal::windows::ctrl_c()?; - Ok(Self { sigint }) + let inner = futures::stream::unfold(sigint, |mut sigint| async move { + sigint.recv().await?; + + Some(((), sigint)) + }) + .boxed(); + + Ok(Self { inner }) + } + + pub fn from_channel(rx: mpsc::Receiver<()>) -> Self { + Self { inner: rx.boxed() } } pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<()> { - self.sigint.poll_recv(cx).map(|_| ()) + self.inner.poll_next_unpin(cx).map(|_| ()) } /// Waits for Ctrl+C