diff --git a/rust/headless-client/src/ipc_service/windows.rs b/rust/headless-client/src/ipc_service/windows.rs index c5d2efb65..e32416503 100644 --- a/rust/headless-client/src/ipc_service/windows.rs +++ b/rust/headless-client/src/ipc_service/windows.rs @@ -1,7 +1,9 @@ use crate::CliCommon; -use anyhow::{Context as _, Result}; +use anyhow::{bail, Context as _, Result}; use connlib_client_shared::file_logger; -use std::{ffi::OsString, time::Duration}; +use futures::future::{self, Either}; +use std::{ffi::OsString, pin::pin, time::Duration}; +use tokio::sync::mpsc; use windows_service::{ service::{ ServiceAccess, ServiceControl, ServiceControlAccept, ServiceErrorControl, ServiceExitCode, @@ -53,13 +55,13 @@ pub(crate) fn run_ipc_service(_cli: CliCommon) -> Result<()> { } // Generates `ffi_service_run` from `service_run` -windows_service::define_windows_service!(ffi_service_run, windows_service_run); +windows_service::define_windows_service!(ffi_service_run, service_run); -fn windows_service_run(arguments: Vec) { +fn service_run(arguments: Vec) { // `arguments` doesn't seem to work right when running as a Windows service // (even though it's meant for that) so just use the default log dir. let handle = super::setup_logging(None).expect("Should be able to set up logging"); - if let Err(error) = fallible_windows_service_run(arguments, handle) { + if let Err(error) = fallible_service_run(arguments, handle) { tracing::error!(?error, "`fallible_windows_service_run` returned an error"); } } @@ -69,16 +71,14 @@ fn windows_service_run(arguments: Vec) { // The arguments don't seem to match the ones passed to the main thread at all. // // If Windows stops us gracefully, this function may never return. -fn fallible_windows_service_run( +fn fallible_service_run( arguments: Vec, logging_handle: file_logger::Handle, ) -> Result<()> { tracing::info!(?arguments, "fallible_windows_service_run"); let rt = tokio::runtime::Runtime::new()?; - - let ipc_task = rt.spawn(super::ipc_listen()); - let ipc_task_ah = ipc_task.abort_handle(); + let (shutdown_tx, shutdown_rx) = mpsc::channel(1); let event_handler = move |control_event| -> ServiceControlHandlerResult { tracing::debug!(?control_event); @@ -90,8 +90,9 @@ fn fallible_windows_service_run( ServiceControlHandlerResult::NoError } ServiceControl::Shutdown | ServiceControl::Stop => { - tracing::info!(?control_event, "Got stop signal from service controller"); - ipc_task_ah.abort(); + if shutdown_tx.blocking_send(()).is_err() { + tracing::error!("Should be able to send shutdown signal"); + } ServiceControlHandlerResult::NoError } ServiceControl::UserEvent(_) => ServiceControlHandlerResult::NoError, @@ -128,17 +129,12 @@ fn fallible_windows_service_run( process_id: None, })?; - let result = match rt.block_on(ipc_task) { - Err(join_error) if join_error.is_cancelled() => { - // We cancelled because Windows asked us to shut down. - Ok(()) - } - Err(join_error) => Err(anyhow::Error::from(join_error).context("`ipc_listen` panicked")), - Ok(Err(error)) => Err(error.context("`ipc_listen` threw an error")), - Ok(Ok(impossible)) => match impossible {}, - }; + // Add new features in `service_run_async` if possible. + // We don't want to bail out of `fallible_service_run` and forget to tell + // Windows that we're shutting down. + let result = rt.block_on(service_run_async(shutdown_rx)); if let Err(error) = &result { - tracing::error!(?error, "`ipc_listen` failed"); + tracing::error!(?error); } // Drop the logging handle so it flushes the logs before we let Windows kill our process. @@ -162,7 +158,28 @@ fn fallible_windows_service_run( wait_hint: Duration::default(), process_id: None, }) - .expect("Should be able to tell Windows we're stopping"); - // Generally unreachable + .context("Should be able to tell Windows we're stopping")?; + // Generally unreachable. Windows typically kills the process first, + // but doesn't guarantee it. Ok(()) } + +/// The main loop for the Windows service +/// +/// This is split off from other functions because we don't want to accidentally +/// bail out of a fallible function and not tell Windows that we're stopping +/// the service. So it's okay to bail out of `service_run_async`, but not +/// out of its caller. +/// +/// Logging must already be set up before calling this. +async fn service_run_async(mut shutdown_rx: mpsc::Receiver<()>) -> Result<()> { + match future::select(pin!(super::ipc_listen()), pin!(shutdown_rx.recv())).await { + Either::Left((Err(error), _)) => Err(error).context("`ipc_listen` threw an error"), + Either::Left((Ok(impossible), _)) => match impossible {}, + Either::Right((None, _)) => bail!("Shutdown channel failed"), + Either::Right((Some(()), _)) => { + tracing::info!("Caught shutdown signal, stopping IPC listener"); + Ok(()) + } + } +}