refactor(ipc-service/windows): remove unnecessary tokio::spawn (#5813)

This also improves some function names (i.e. don't say `windows_` when
we're already in `windows.rs`) and adds comments justifying why some
functions with only one call site are split out

I started this intending to use it to practice the sans-I/O style. It
didn't come up but I did get rid of that `spawn`
This commit is contained in:
Reactor Scram
2024-07-11 14:17:55 +00:00
committed by GitHub
parent 8ec6a809a1
commit cb2bddae7e

View File

@@ -1,7 +1,9 @@
use crate::CliCommon;
use anyhow::{Context as _, Result};
use anyhow::{bail, Context as _, Result};
use connlib_client_shared::file_logger;
use std::{ffi::OsString, time::Duration};
use futures::future::{self, Either};
use std::{ffi::OsString, pin::pin, time::Duration};
use tokio::sync::mpsc;
use windows_service::{
service::{
ServiceAccess, ServiceControl, ServiceControlAccept, ServiceErrorControl, ServiceExitCode,
@@ -53,13 +55,13 @@ pub(crate) fn run_ipc_service(_cli: CliCommon) -> Result<()> {
}
// Generates `ffi_service_run` from `service_run`
windows_service::define_windows_service!(ffi_service_run, windows_service_run);
windows_service::define_windows_service!(ffi_service_run, service_run);
fn windows_service_run(arguments: Vec<OsString>) {
fn service_run(arguments: Vec<OsString>) {
// `arguments` doesn't seem to work right when running as a Windows service
// (even though it's meant for that) so just use the default log dir.
let handle = super::setup_logging(None).expect("Should be able to set up logging");
if let Err(error) = fallible_windows_service_run(arguments, handle) {
if let Err(error) = fallible_service_run(arguments, handle) {
tracing::error!(?error, "`fallible_windows_service_run` returned an error");
}
}
@@ -69,16 +71,14 @@ fn windows_service_run(arguments: Vec<OsString>) {
// The arguments don't seem to match the ones passed to the main thread at all.
//
// If Windows stops us gracefully, this function may never return.
fn fallible_windows_service_run(
fn fallible_service_run(
arguments: Vec<OsString>,
logging_handle: file_logger::Handle,
) -> Result<()> {
tracing::info!(?arguments, "fallible_windows_service_run");
let rt = tokio::runtime::Runtime::new()?;
let ipc_task = rt.spawn(super::ipc_listen());
let ipc_task_ah = ipc_task.abort_handle();
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let event_handler = move |control_event| -> ServiceControlHandlerResult {
tracing::debug!(?control_event);
@@ -90,8 +90,9 @@ fn fallible_windows_service_run(
ServiceControlHandlerResult::NoError
}
ServiceControl::Shutdown | ServiceControl::Stop => {
tracing::info!(?control_event, "Got stop signal from service controller");
ipc_task_ah.abort();
if shutdown_tx.blocking_send(()).is_err() {
tracing::error!("Should be able to send shutdown signal");
}
ServiceControlHandlerResult::NoError
}
ServiceControl::UserEvent(_) => ServiceControlHandlerResult::NoError,
@@ -128,17 +129,12 @@ fn fallible_windows_service_run(
process_id: None,
})?;
let result = match rt.block_on(ipc_task) {
Err(join_error) if join_error.is_cancelled() => {
// We cancelled because Windows asked us to shut down.
Ok(())
}
Err(join_error) => Err(anyhow::Error::from(join_error).context("`ipc_listen` panicked")),
Ok(Err(error)) => Err(error.context("`ipc_listen` threw an error")),
Ok(Ok(impossible)) => match impossible {},
};
// Add new features in `service_run_async` if possible.
// We don't want to bail out of `fallible_service_run` and forget to tell
// Windows that we're shutting down.
let result = rt.block_on(service_run_async(shutdown_rx));
if let Err(error) = &result {
tracing::error!(?error, "`ipc_listen` failed");
tracing::error!(?error);
}
// Drop the logging handle so it flushes the logs before we let Windows kill our process.
@@ -162,7 +158,28 @@ fn fallible_windows_service_run(
wait_hint: Duration::default(),
process_id: None,
})
.expect("Should be able to tell Windows we're stopping");
// Generally unreachable
.context("Should be able to tell Windows we're stopping")?;
// Generally unreachable. Windows typically kills the process first,
// but doesn't guarantee it.
Ok(())
}
/// The main loop for the Windows service
///
/// This is split off from other functions because we don't want to accidentally
/// bail out of a fallible function and not tell Windows that we're stopping
/// the service. So it's okay to bail out of `service_run_async`, but not
/// out of its caller.
///
/// Logging must already be set up before calling this.
async fn service_run_async(mut shutdown_rx: mpsc::Receiver<()>) -> Result<()> {
match future::select(pin!(super::ipc_listen()), pin!(shutdown_rx.recv())).await {
Either::Left((Err(error), _)) => Err(error).context("`ipc_listen` threw an error"),
Either::Left((Ok(impossible), _)) => match impossible {},
Either::Right((None, _)) => bail!("Shutdown channel failed"),
Either::Right((Some(()), _)) => {
tracing::info!("Caught shutdown signal, stopping IPC listener");
Ok(())
}
}
}