diff --git a/rust/headless-client/src/dns_control/linux.rs b/rust/headless-client/src/dns_control/linux.rs index 9f9d5cbf5..49b9d10d1 100644 --- a/rust/headless-client/src/dns_control/linux.rs +++ b/rust/headless-client/src/dns_control/linux.rs @@ -29,6 +29,16 @@ pub(crate) struct DnsController { dns_control_method: Option, } +impl Default for DnsController { + fn default() -> Self { + // We'll remove `get_dns_control_from_env` in #5068 + let dns_control_method = get_dns_control_from_env(); + tracing::info!(?dns_control_method); + + Self { dns_control_method } + } +} + impl Drop for DnsController { fn drop(&mut self) { tracing::debug!("Reverting DNS control..."); @@ -40,14 +50,6 @@ impl Drop for DnsController { } impl DnsController { - pub(crate) fn new() -> Self { - // We'll remove `get_dns_control_from_env` in #5068 - let dns_control_method = get_dns_control_from_env(); - tracing::info!(?dns_control_method); - - Self { dns_control_method } - } - /// Set the computer's system-wide DNS servers /// /// The `mut` in `&mut self` is not needed by Rust's rules, but diff --git a/rust/headless-client/src/dns_control/windows.rs b/rust/headless-client/src/dns_control/windows.rs index abc0f2653..19048f83f 100644 --- a/rust/headless-client/src/dns_control/windows.rs +++ b/rust/headless-client/src/dns_control/windows.rs @@ -21,6 +21,7 @@ pub fn system_resolvers_for_gui() -> Result> { system_resolvers() } +#[derive(Default)] pub(crate) struct DnsController {} // Unique magic number that we can use to delete our well-known NRPT rule. @@ -36,10 +37,6 @@ impl Drop for DnsController { } impl DnsController { - pub(crate) fn new() -> Self { - Self {} - } - /// Set the computer's system-wide DNS servers /// /// There's a gap in this because on Windows we deactivate and re-activate control. diff --git a/rust/headless-client/src/lib.rs b/rust/headless-client/src/lib.rs index 914abea89..6ca874644 100644 --- a/rust/headless-client/src/lib.rs +++ b/rust/headless-client/src/lib.rs @@ -21,11 +21,16 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, pin::pin, + time::Duration, +}; +use tokio::{ + io::{ReadHalf, WriteHalf}, + sync::mpsc, + time::Instant, }; -use tokio::sync::mpsc; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; use tracing::subscriber::set_global_default; -use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Layer as _, Registry}; +use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Layer, Registry}; use url::Url; use platform::default_token_path; @@ -51,6 +56,7 @@ pub mod windows; #[cfg(target_os = "windows")] pub(crate) use windows as platform; +use dns_control::DnsController; use ipc::{Server as IpcServer, Stream as IpcStream}; /// Only used on Linux @@ -298,7 +304,7 @@ pub fn run_only_headless_client() -> Result<()> { platform::notify_service_controller()?; let result = rt.block_on(async { - let mut dns_controller = dns_control::DnsController::new(); + let mut dns_controller = dns_control::DnsController::default(); let mut tun_device = tun_device_manager::TunDeviceManager::new()?; let mut signals = Signals::new()?; @@ -447,49 +453,115 @@ async fn ipc_listen() -> Result { .next_client() .await .context("Failed to wait for incoming IPC connection from a GUI")?; - if let Err(error) = handle_ipc_client(stream).await { - tracing::error!(?error, "Error while handling IPC client"); - } + Handler::new(stream)? + .run() + .await + .context("Error while handling IPC client")?; } } -async fn handle_ipc_client(stream: IpcStream) -> Result<()> { - let (rx, tx) = tokio::io::split(stream); - let mut rx = FramedRead::new(rx, LengthDelimitedCodec::new()); - let mut tx = FramedWrite::new(tx, LengthDelimitedCodec::new()); - let (cb_tx, mut cb_rx) = mpsc::channel(10); +/// Handles one IPC client +struct Handler { + callback_handler: CallbackHandler, + cb_rx: mpsc::Receiver, + connlib: Option, + dns_controller: DnsController, + ipc_rx: FramedRead, LengthDelimitedCodec>, + ipc_tx: FramedWrite, LengthDelimitedCodec>, + last_connlib_start_instant: Option, + tun_device: tun_device_manager::TunDeviceManager, +} - let send_task = tokio::spawn(async move { - let mut dns_controller = dns_control::DnsController::new(); - let mut tun_device = tun_device_manager::TunDeviceManager::new()?; +enum Event { + Callback(InternalServerMsg), + Ipc(IpcClientMsg), +} - while let Some(msg) = cb_rx.recv().await { - match msg { - InternalServerMsg::Ipc(msg) => tx.send(serde_json::to_string(&msg)?.into()).await?, - InternalServerMsg::OnSetInterfaceConfig { ipv4, ipv6, dns } => { - tun_device.set_ips(ipv4, ipv6).await?; - dns_controller.set_dns(&dns).await?; - tx.send(serde_json::to_string(&IpcServerMsg::OnTunnelReady)?.into()) - .await?; - } - InternalServerMsg::OnUpdateRoutes { ipv4, ipv6 } => { - tun_device.set_routes(ipv4, ipv6).await? +impl Handler { + fn new(stream: IpcStream) -> Result { + let (rx, tx) = tokio::io::split(stream); + let ipc_rx = FramedRead::new(rx, LengthDelimitedCodec::new()); + let ipc_tx = FramedWrite::new(tx, LengthDelimitedCodec::new()); + let (cb_tx, cb_rx) = mpsc::channel(10); + let tun_device = tun_device_manager::TunDeviceManager::new()?; + + Ok(Self { + callback_handler: CallbackHandler { cb_tx }, + cb_rx, + connlib: None, + dns_controller: Default::default(), + ipc_rx, + ipc_tx, + last_connlib_start_instant: None, + tun_device, + }) + } + + async fn run(&mut self) -> Result<()> { + loop { + let event = { + // This borrows `self` so we must drop it before handling the `Event`. + let cb = pin!(self.cb_rx.recv()); + match future::select(self.ipc_rx.next(), cb).await { + future::Either::Left((Some(Ok(x)), _)) => Event::Ipc( + serde_json::from_slice(&x) + .context("Error while deserializing IPC message")?, + ), // TODO: Integrate the serde_json stuff into a custom Tokio codec + future::Either::Left((Some(Err(error)), _)) => Err(error)?, + future::Either::Left((None, _)) => { + tracing::info!("IPC client disconnected"); + break; + } + future::Either::Right((Some(x), _)) => Event::Callback(x), + future::Either::Right((None, _)) => { + tracing::error!("Impossible - Callback channel closed"); + break; + } } + }; + match event { + Event::Callback(x) => self.handle_connlib_cb(x).await?, + Event::Ipc(msg) => self + .handle_ipc_msg(msg) + .context("Error while handling IPC message from client")?, } } - Ok::<_, anyhow::Error>(()) - }); + Ok(()) + } - let mut connlib = None; - let callback_handler = CallbackHandler { cb_tx }; - while let Some(msg) = rx.next().await { - let msg = msg?; - let msg: IpcClientMsg = serde_json::from_slice(&msg)?; + async fn handle_connlib_cb(&mut self, msg: InternalServerMsg) -> Result<()> { + match msg { + InternalServerMsg::Ipc(msg) => { + // The first `OnUpdateResources` marks when connlib is fully initialized + if let IpcServerMsg::OnUpdateResources(_) = &msg { + if let Some(instant) = self.last_connlib_start_instant.take() { + let dur = instant.elapsed(); + tracing::info!(?dur, "Connlib started"); + } + } + self.ipc_tx + .send(serde_json::to_string(&msg)?.into()) + .await? + } + InternalServerMsg::OnSetInterfaceConfig { ipv4, ipv6, dns } => { + self.tun_device.set_ips(ipv4, ipv6).await?; + self.dns_controller.set_dns(&dns).await?; + self.ipc_tx + .send(serde_json::to_string(&IpcServerMsg::OnTunnelReady)?.into()) + .await?; + } + InternalServerMsg::OnUpdateRoutes { ipv4, ipv6 } => { + self.tun_device.set_routes(ipv4, ipv6).await? + } + } + Ok(()) + } + fn handle_ipc_msg(&mut self, msg: IpcClientMsg) -> Result<()> { match msg { IpcClientMsg::Connect { api_url, token } => { let token = secrecy::SecretString::from(token); - assert!(connlib.is_none()); + assert!(self.connlib.is_none()); let device_id = device_id::get_or_create().context("Failed to get / create device ID")?; let (private_key, public_key) = keypair(); @@ -502,31 +574,39 @@ async fn handle_ipc_client(stream: IpcStream) -> Result<()> { public_key.to_bytes(), )?; + self.last_connlib_start_instant = Some(Instant::now()); let new_session = connlib_client_shared::Session::connect( login, Sockets::new(), private_key, None, - callback_handler.clone(), - Some(std::time::Duration::from_secs(60 * 60 * 24 * 30)), + self.callback_handler.clone(), + Some(Duration::from_secs(60 * 60 * 24 * 30)), tokio::runtime::Handle::try_current()?, ); new_session.set_dns(dns_control::system_resolvers().unwrap_or_default()); - connlib = Some(new_session); + self.connlib = Some(new_session); } IpcClientMsg::Disconnect => { - if let Some(connlib) = connlib.take() { + if let Some(connlib) = self.connlib.take() { connlib.disconnect(); + } else { + tracing::error!("Error - Got Disconnect when we're already not connected"); } } - IpcClientMsg::Reconnect => connlib.as_mut().context("No connlib session")?.reconnect(), - IpcClientMsg::SetDns(v) => connlib.as_mut().context("No connlib session")?.set_dns(v), + IpcClientMsg::Reconnect => self + .connlib + .as_mut() + .context("No connlib session")? + .reconnect(), + IpcClientMsg::SetDns(v) => self + .connlib + .as_mut() + .context("No connlib session")? + .set_dns(v), } + Ok(()) } - - send_task.abort(); - - Ok(()) } #[allow(dead_code)]