diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 4201f61b5..c35a80f2a 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1966,6 +1966,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-stream", "tokio-util", "tracing", "tracing-subscriber", diff --git a/rust/headless-client/Cargo.toml b/rust/headless-client/Cargo.toml index 1f8c50636..f496554b9 100644 --- a/rust/headless-client/Cargo.toml +++ b/rust/headless-client/Cargo.toml @@ -25,6 +25,7 @@ thiserror = { version = "1.0", default-features = false } # This actually relies on many other features in Tokio, so this will probably # fail to build outside the workspace. tokio = { version = "1.38.0", features = ["macros", "signal"] } +tokio-stream = "0.1.15" tokio-util = { version = "0.7.11", features = ["codec"] } tracing = { workspace = true } tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } diff --git a/rust/headless-client/src/ipc_service.rs b/rust/headless-client/src/ipc_service.rs index 63e7342fa..43a348da4 100644 --- a/rust/headless-client/src/ipc_service.rs +++ b/rust/headless-client/src/ipc_service.rs @@ -1,10 +1,10 @@ use crate::{ device_id, dns_control::{self, DnsController}, - known_dirs, CallbackHandler, CliCommon, InternalServerMsg, IpcServerMsg, SignalKind, Signals, + known_dirs, signals, CallbackHandler, CliCommon, InternalServerMsg, IpcServerMsg, TOKEN_ENV_KEY, }; -use anyhow::{bail, Context as _, Result}; +use anyhow::{Context as _, Result}; use clap::Parser; use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session, Sockets}; use connlib_shared::tun_device_manager; @@ -98,7 +98,7 @@ fn run_debug_ipc_service() -> Result<()> { ); let rt = tokio::runtime::Runtime::new()?; let _guard = rt.enter(); - let mut signals = Signals::new()?; + let mut signals = signals::Terminate::new()?; rt.block_on(ipc_listen_with_signals(&mut signals)) } @@ -107,19 +107,12 @@ fn run_debug_ipc_service() -> Result<()> { /// /// Shared between the Linux systemd service and the debug subcommand /// TODO: Better name -async fn ipc_listen_with_signals(signals: &mut Signals) -> Result<()> { +async fn ipc_listen_with_signals(signals: &mut signals::Terminate) -> Result<()> { let ipc_service = pin!(ipc_listen()); match future::select(pin!(signals.recv()), ipc_service).await { - future::Either::Left((SignalKind::Hangup, _)) => { - bail!("Exiting, SIGHUP not implemented for the IPC service"); - } - future::Either::Left((SignalKind::Interrupt, _)) => { - tracing::info!("Caught SIGINT"); - Ok(()) - } - future::Either::Left((SignalKind::Terminate, _)) => { - tracing::info!("Caught SIGTERM"); + future::Either::Left(((), _)) => { + tracing::info!("Caught SIGINT / SIGTERM / Ctrl+C"); Ok(()) } future::Either::Right((Ok(impossible), _)) => match impossible {}, diff --git a/rust/headless-client/src/ipc_service/linux.rs b/rust/headless-client/src/ipc_service/linux.rs index 3f62e6239..85fa46d46 100644 --- a/rust/headless-client/src/ipc_service/linux.rs +++ b/rust/headless-client/src/ipc_service/linux.rs @@ -1,5 +1,5 @@ use super::CliCommon; -use crate::Signals; +use crate::signals; use anyhow::{bail, Result}; /// Cross-platform entry point for systemd / Windows services @@ -12,7 +12,7 @@ pub(crate) fn run_ipc_service(cli: CliCommon) -> Result<()> { } let rt = tokio::runtime::Runtime::new()?; let _guard = rt.enter(); - let mut signals = Signals::new()?; + let mut signals = signals::Terminate::new()?; rt.block_on(super::ipc_listen_with_signals(&mut signals)) } diff --git a/rust/headless-client/src/lib.rs b/rust/headless-client/src/lib.rs index 8e1514987..cd9aef98d 100644 --- a/rust/headless-client/src/lib.rs +++ b/rust/headless-client/src/lib.rs @@ -20,10 +20,6 @@ use tracing::subscriber::set_global_default; use tracing_subscriber::{fmt, layer::SubscriberExt as _, EnvFilter, Layer as _, Registry}; use platform::default_token_path; -/// SIGINT and, on Linux, SIGHUP. -/// -/// Must be constructed inside a Tokio runtime context. -use platform::Signals; /// Generate a persistent device ID, stores it to disk, and reads it back. pub(crate) mod device_id; @@ -31,6 +27,7 @@ pub(crate) mod device_id; pub mod dns_control; mod ipc_service; pub mod known_dirs; +mod signals; mod standalone; pub mod uptime; @@ -156,20 +153,6 @@ impl Callbacks for CallbackHandler { } } -#[allow(dead_code)] -enum SignalKind { - /// SIGHUP - /// - /// Not caught on Windows - Hangup, - /// SIGINT - Interrupt, - /// SIGTERM - /// - /// Not caught on Windows - Terminate, -} - /// Sets up logging for stdout only, with INFO level by default pub fn setup_stdout_logging() -> Result<()> { let filter = EnvFilter::new(ipc_service::get_log_filter().context("Can't read log filter")?); diff --git a/rust/headless-client/src/linux.rs b/rust/headless-client/src/linux.rs index 7e2556908..48c925e72 100644 --- a/rust/headless-client/src/linux.rs +++ b/rust/headless-client/src/linux.rs @@ -1,50 +1,14 @@ //! Implementation, Linux-specific -use super::{SignalKind, TOKEN_ENV_KEY}; +use super::TOKEN_ENV_KEY; use anyhow::{bail, Result}; -use futures::future::FutureExt as _; -use std::{ - path::{Path, PathBuf}, - pin::pin, -}; -use tokio::signal::unix::{signal, Signal, SignalKind as TokioSignalKind}; +use std::path::{Path, PathBuf}; // The Client currently must run as root to control DNS // Root group and user are used to check file ownership on the token const ROOT_GROUP: u32 = 0; const ROOT_USER: u32 = 0; -pub(crate) struct Signals { - /// For reloading settings in the standalone Client - sighup: Signal, - /// For Ctrl+C from a terminal - sigint: Signal, - /// For systemd service stopping - sigterm: Signal, -} - -impl Signals { - pub(crate) fn new() -> Result { - let sighup = signal(TokioSignalKind::hangup())?; - let sigint = signal(TokioSignalKind::interrupt())?; - let sigterm = signal(TokioSignalKind::terminate())?; - - Ok(Self { - sighup, - sigint, - sigterm, - }) - } - - pub(crate) async fn recv(&mut self) -> SignalKind { - futures::select! { - _ = pin!(self.sighup.recv().fuse()) => SignalKind::Hangup, - _ = pin!(self.sigint.recv().fuse()) => SignalKind::Interrupt, - _ = pin!(self.sigterm.recv().fuse()) => SignalKind::Terminate, - } - } -} - pub(crate) fn default_token_path() -> PathBuf { PathBuf::from("/etc") .join(connlib_shared::BUNDLE_ID) diff --git a/rust/headless-client/src/signals.rs b/rust/headless-client/src/signals.rs new file mode 100644 index 000000000..e61d4fe73 --- /dev/null +++ b/rust/headless-client/src/signals.rs @@ -0,0 +1,9 @@ +#[cfg(target_os = "linux")] +#[path = "signals/linux.rs"] +mod platform; + +#[cfg(target_os = "windows")] +#[path = "signals/windows.rs"] +mod platform; + +pub(crate) use platform::{Hangup, Terminate}; diff --git a/rust/headless-client/src/signals/linux.rs b/rust/headless-client/src/signals/linux.rs new file mode 100644 index 000000000..bf6ba46eb --- /dev/null +++ b/rust/headless-client/src/signals/linux.rs @@ -0,0 +1,46 @@ +use anyhow::Result; +use futures::FutureExt as _; +use std::pin::pin; +use tokio::signal::unix::{signal, Signal, SignalKind}; + +pub(crate) struct Terminate { + /// For Ctrl+C from a terminal + sigint: Signal, + /// For systemd service stopping + sigterm: Signal, +} + +pub(crate) struct Hangup { + /// For reloading settings in the standalone Client + sighup: Signal, +} + +impl Terminate { + pub(crate) fn new() -> Result { + let sigint = signal(SignalKind::interrupt())?; + let sigterm = signal(SignalKind::terminate())?; + + Ok(Self { sigint, sigterm }) + } + + /// Waits for SIGINT or SIGTERM + pub(crate) async fn recv(&mut self) { + futures::select! { + _ = pin!(self.sigint.recv().fuse()) => {}, + _ = pin!(self.sigterm.recv().fuse()) => {}, + } + } +} + +impl Hangup { + pub(crate) fn new() -> Result { + let sighup = signal(SignalKind::hangup())?; + + Ok(Self { sighup }) + } + + /// Waits for SIGHUP + pub(crate) async fn recv(&mut self) { + self.sighup.recv().await; + } +} diff --git a/rust/headless-client/src/signals/windows.rs b/rust/headless-client/src/signals/windows.rs new file mode 100644 index 000000000..bb3bae3ac --- /dev/null +++ b/rust/headless-client/src/signals/windows.rs @@ -0,0 +1,35 @@ +use anyhow::Result; + +// This looks like a pointless wrapper around `CtrlC`, because it must match +// the Linux signatures +pub(crate) struct Terminate { + sigint: tokio::signal::windows::CtrlC, +} + +// SIGHUP is used on Linux but not on Windows +pub(crate) struct Hangup {} + +impl Terminate { + pub(crate) fn new() -> Result { + let sigint = tokio::signal::windows::ctrl_c()?; + Ok(Self { sigint }) + } + + /// Waits for Ctrl+C + pub(crate) async fn recv(&mut self) { + self.sigint.recv().await; + } +} + +impl Hangup { + #[allow(clippy::unnecessary_wraps)] + pub(crate) fn new() -> Result { + Ok(Self {}) + } + + /// Waits forever - Only implemented for Linux + pub(crate) async fn recv(&mut self) { + let () = std::future::pending().await; + unreachable!() + } +} diff --git a/rust/headless-client/src/standalone.rs b/rust/headless-client/src/standalone.rs index 84082a880..34027ae3b 100644 --- a/rust/headless-client/src/standalone.rs +++ b/rust/headless-client/src/standalone.rs @@ -1,21 +1,22 @@ //! AKA "Headless" use crate::{ - default_token_path, device_id, dns_control, platform, CallbackHandler, CliCommon, - DnsController, InternalServerMsg, IpcServerMsg, SignalKind, Signals, TOKEN_ENV_KEY, + default_token_path, device_id, dns_control, platform, signals, CallbackHandler, CliCommon, + DnsController, InternalServerMsg, IpcServerMsg, TOKEN_ENV_KEY, }; use anyhow::{anyhow, Context as _, Result}; use clap::Parser; use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session, Sockets}; use connlib_shared::tun_device_manager; use firezone_cli_utils::setup_global_subscriber; -use futures::future; +use futures::{FutureExt as _, StreamExt as _}; use secrecy::SecretString; use std::{ path::{Path, PathBuf}, pin::pin, }; use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; /// Command-line args for the headless Client #[derive(clap::Parser)] @@ -127,7 +128,6 @@ pub fn run_only_headless_client() -> Result<()> { ) })?; tracing::info!("Running in headless / standalone mode"); - let _guard = rt.enter(); // TODO: Should this default to 30 days? let max_partition_time = cli.common.max_partition_time.map(|d| d.into()); @@ -151,7 +151,7 @@ pub fn run_only_headless_client() -> Result<()> { return Ok(()); } - let (cb_tx, mut cb_rx) = mpsc::channel(10); + let (cb_tx, cb_rx) = mpsc::channel(10); let callbacks = CallbackHandler { cb_tx }; platform::setup_before_connlib()?; @@ -170,47 +170,47 @@ pub fn run_only_headless_client() -> Result<()> { platform::notify_service_controller()?; let result = rt.block_on(async { + let mut terminate = signals::Terminate::new()?; + let mut hangup = signals::Hangup::new()?; + let mut terminate = pin!(terminate.recv().fuse()); + let mut hangup = pin!(hangup.recv().fuse()); let mut dns_controller = DnsController::default(); let mut tun_device = tun_device_manager::TunDeviceManager::new()?; - let mut signals = Signals::new()?; + let mut cb_rx = ReceiverStream::new(cb_rx).fuse(); loop { - match future::select(pin!(signals.recv()), pin!(cb_rx.recv())).await { - future::Either::Left((SignalKind::Hangup, _)) => { + let cb = futures::select! { + () = terminate => { + tracing::info!("Caught SIGINT / SIGTERM / Ctrl+C"); + return Ok(()); + }, + () = hangup => { tracing::info!("Caught SIGHUP"); session.reconnect(); - } - future::Either::Left((SignalKind::Interrupt, _)) => { - tracing::info!("Caught SIGINT"); - return Ok(()); - } - future::Either::Left((SignalKind::Terminate, _)) => { - tracing::info!("Caught SIGTERM"); - return Ok(()); - } - future::Either::Right((None, _)) => { - return Err(anyhow::anyhow!("cb_rx unexpectedly ran empty")); - } - future::Either::Right((Some(msg), _)) => match msg { - // TODO: Headless Client shouldn't be using messages labelled `Ipc` - InternalServerMsg::Ipc(IpcServerMsg::OnDisconnect { - error_msg, - is_authentication_error: _, - }) => return Err(anyhow!(error_msg).context("Firezone disconnected")), - InternalServerMsg::Ipc(IpcServerMsg::Ok) - | InternalServerMsg::Ipc(IpcServerMsg::OnTunnelReady) => {} - InternalServerMsg::Ipc(IpcServerMsg::OnUpdateResources(_)) => { - // On every resources update, flush DNS to mitigate - dns_controller.flush()?; - } - InternalServerMsg::OnSetInterfaceConfig { ipv4, ipv6, dns } => { - tun_device.set_ips(ipv4, ipv6).await?; - dns_controller.set_dns(&dns).await?; - } - InternalServerMsg::OnUpdateRoutes { ipv4, ipv6 } => { - tun_device.set_routes(ipv4, ipv6).await? - } + continue; }, + cb = cb_rx.next() => cb.context("cb_rx unexpectedly ran empty")?, + }; + + match cb { + // TODO: Headless Client shouldn't be using messages labelled `Ipc` + InternalServerMsg::Ipc(IpcServerMsg::OnDisconnect { + error_msg, + is_authentication_error: _, + }) => return Err(anyhow!(error_msg).context("Firezone disconnected")), + InternalServerMsg::Ipc(IpcServerMsg::Ok) + | InternalServerMsg::Ipc(IpcServerMsg::OnTunnelReady) => {} + InternalServerMsg::Ipc(IpcServerMsg::OnUpdateResources(_)) => { + // On every resources update, flush DNS to mitigate + dns_controller.flush()?; + } + InternalServerMsg::OnSetInterfaceConfig { ipv4, ipv6, dns } => { + tun_device.set_ips(ipv4, ipv6).await?; + dns_controller.set_dns(&dns).await?; + } + InternalServerMsg::OnUpdateRoutes { ipv4, ipv6 } => { + tun_device.set_routes(ipv4, ipv6).await? + } } } }); diff --git a/rust/headless-client/src/windows.rs b/rust/headless-client/src/windows.rs index be607d836..e88c5847e 100644 --- a/rust/headless-client/src/windows.rs +++ b/rust/headless-client/src/windows.rs @@ -4,31 +4,12 @@ //! service to be stopped even if its only process ends, for some reason. //! We must tell Windows explicitly when our service is stopping. -use crate::SignalKind; use anyhow::Result; use std::path::{Path, PathBuf}; #[path = "windows/wintun_install.rs"] mod wintun_install; -// This looks like a pointless wrapper around `CtrlC`, because it must match -// the Linux signatures -pub(crate) struct Signals { - sigint: tokio::signal::windows::CtrlC, -} - -impl Signals { - pub(crate) fn new() -> Result { - let sigint = tokio::signal::windows::ctrl_c()?; - Ok(Self { sigint }) - } - - pub(crate) async fn recv(&mut self) -> SignalKind { - self.sigint.recv().await; - SignalKind::Interrupt - } -} - // The return value is useful on Linux #[allow(clippy::unnecessary_wraps)] pub(crate) fn check_token_permissions(_path: &Path) -> Result<()> {