refactor(headless-client): rewrite the IPC service main loop so we can time the Client startup (#5376)

Part of a yak shave to profile startup time for reducing it on Windows
#5026

Median of 3 runs:

- Windows 11 aarch64 Parallels VM - 4.8 s
- Windows 11 x86_64 laptop - 3.1 s (I thought it used to be slower)
- Windows Server 2022 VM - 22.2 s

---------

Signed-off-by: Reactor Scram <ReactorScram@users.noreply.github.com>
Co-authored-by: Jamil <jamilbk@users.noreply.github.com>
Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
Reactor Scram
2024-06-17 11:57:43 -05:00
committed by GitHub
parent 1930e62404
commit a9a0a6c450
3 changed files with 134 additions and 55 deletions

View File

@@ -29,6 +29,16 @@ pub(crate) struct DnsController {
dns_control_method: Option<DnsControlMethod>,
}
impl Default for DnsController {
fn default() -> Self {
// We'll remove `get_dns_control_from_env` in #5068
let dns_control_method = get_dns_control_from_env();
tracing::info!(?dns_control_method);
Self { dns_control_method }
}
}
impl Drop for DnsController {
fn drop(&mut self) {
tracing::debug!("Reverting DNS control...");
@@ -40,14 +50,6 @@ impl Drop for DnsController {
}
impl DnsController {
pub(crate) fn new() -> Self {
// We'll remove `get_dns_control_from_env` in #5068
let dns_control_method = get_dns_control_from_env();
tracing::info!(?dns_control_method);
Self { dns_control_method }
}
/// Set the computer's system-wide DNS servers
///
/// The `mut` in `&mut self` is not needed by Rust's rules, but

View File

@@ -21,6 +21,7 @@ pub fn system_resolvers_for_gui() -> Result<Vec<IpAddr>> {
system_resolvers()
}
#[derive(Default)]
pub(crate) struct DnsController {}
// Unique magic number that we can use to delete our well-known NRPT rule.
@@ -36,10 +37,6 @@ impl Drop for DnsController {
}
impl DnsController {
pub(crate) fn new() -> Self {
Self {}
}
/// Set the computer's system-wide DNS servers
///
/// There's a gap in this because on Windows we deactivate and re-activate control.

View File

@@ -21,11 +21,16 @@ use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::{Path, PathBuf},
pin::pin,
time::Duration,
};
use tokio::{
io::{ReadHalf, WriteHalf},
sync::mpsc,
time::Instant,
};
use tokio::sync::mpsc;
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::subscriber::set_global_default;
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Layer as _, Registry};
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Layer, Registry};
use url::Url;
use platform::default_token_path;
@@ -51,6 +56,7 @@ pub mod windows;
#[cfg(target_os = "windows")]
pub(crate) use windows as platform;
use dns_control::DnsController;
use ipc::{Server as IpcServer, Stream as IpcStream};
/// Only used on Linux
@@ -298,7 +304,7 @@ pub fn run_only_headless_client() -> Result<()> {
platform::notify_service_controller()?;
let result = rt.block_on(async {
let mut dns_controller = dns_control::DnsController::new();
let mut dns_controller = dns_control::DnsController::default();
let mut tun_device = tun_device_manager::TunDeviceManager::new()?;
let mut signals = Signals::new()?;
@@ -447,49 +453,115 @@ async fn ipc_listen() -> Result<std::convert::Infallible> {
.next_client()
.await
.context("Failed to wait for incoming IPC connection from a GUI")?;
if let Err(error) = handle_ipc_client(stream).await {
tracing::error!(?error, "Error while handling IPC client");
}
Handler::new(stream)?
.run()
.await
.context("Error while handling IPC client")?;
}
}
async fn handle_ipc_client(stream: IpcStream) -> Result<()> {
let (rx, tx) = tokio::io::split(stream);
let mut rx = FramedRead::new(rx, LengthDelimitedCodec::new());
let mut tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
let (cb_tx, mut cb_rx) = mpsc::channel(10);
/// Handles one IPC client
struct Handler {
callback_handler: CallbackHandler,
cb_rx: mpsc::Receiver<InternalServerMsg>,
connlib: Option<connlib_client_shared::Session>,
dns_controller: DnsController,
ipc_rx: FramedRead<ReadHalf<IpcStream>, LengthDelimitedCodec>,
ipc_tx: FramedWrite<WriteHalf<IpcStream>, LengthDelimitedCodec>,
last_connlib_start_instant: Option<Instant>,
tun_device: tun_device_manager::TunDeviceManager,
}
let send_task = tokio::spawn(async move {
let mut dns_controller = dns_control::DnsController::new();
let mut tun_device = tun_device_manager::TunDeviceManager::new()?;
enum Event {
Callback(InternalServerMsg),
Ipc(IpcClientMsg),
}
while let Some(msg) = cb_rx.recv().await {
match msg {
InternalServerMsg::Ipc(msg) => tx.send(serde_json::to_string(&msg)?.into()).await?,
InternalServerMsg::OnSetInterfaceConfig { ipv4, ipv6, dns } => {
tun_device.set_ips(ipv4, ipv6).await?;
dns_controller.set_dns(&dns).await?;
tx.send(serde_json::to_string(&IpcServerMsg::OnTunnelReady)?.into())
.await?;
}
InternalServerMsg::OnUpdateRoutes { ipv4, ipv6 } => {
tun_device.set_routes(ipv4, ipv6).await?
impl Handler {
fn new(stream: IpcStream) -> Result<Self> {
let (rx, tx) = tokio::io::split(stream);
let ipc_rx = FramedRead::new(rx, LengthDelimitedCodec::new());
let ipc_tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
let (cb_tx, cb_rx) = mpsc::channel(10);
let tun_device = tun_device_manager::TunDeviceManager::new()?;
Ok(Self {
callback_handler: CallbackHandler { cb_tx },
cb_rx,
connlib: None,
dns_controller: Default::default(),
ipc_rx,
ipc_tx,
last_connlib_start_instant: None,
tun_device,
})
}
async fn run(&mut self) -> Result<()> {
loop {
let event = {
// This borrows `self` so we must drop it before handling the `Event`.
let cb = pin!(self.cb_rx.recv());
match future::select(self.ipc_rx.next(), cb).await {
future::Either::Left((Some(Ok(x)), _)) => Event::Ipc(
serde_json::from_slice(&x)
.context("Error while deserializing IPC message")?,
), // TODO: Integrate the serde_json stuff into a custom Tokio codec
future::Either::Left((Some(Err(error)), _)) => Err(error)?,
future::Either::Left((None, _)) => {
tracing::info!("IPC client disconnected");
break;
}
future::Either::Right((Some(x), _)) => Event::Callback(x),
future::Either::Right((None, _)) => {
tracing::error!("Impossible - Callback channel closed");
break;
}
}
};
match event {
Event::Callback(x) => self.handle_connlib_cb(x).await?,
Event::Ipc(msg) => self
.handle_ipc_msg(msg)
.context("Error while handling IPC message from client")?,
}
}
Ok::<_, anyhow::Error>(())
});
Ok(())
}
let mut connlib = None;
let callback_handler = CallbackHandler { cb_tx };
while let Some(msg) = rx.next().await {
let msg = msg?;
let msg: IpcClientMsg = serde_json::from_slice(&msg)?;
async fn handle_connlib_cb(&mut self, msg: InternalServerMsg) -> Result<()> {
match msg {
InternalServerMsg::Ipc(msg) => {
// The first `OnUpdateResources` marks when connlib is fully initialized
if let IpcServerMsg::OnUpdateResources(_) = &msg {
if let Some(instant) = self.last_connlib_start_instant.take() {
let dur = instant.elapsed();
tracing::info!(?dur, "Connlib started");
}
}
self.ipc_tx
.send(serde_json::to_string(&msg)?.into())
.await?
}
InternalServerMsg::OnSetInterfaceConfig { ipv4, ipv6, dns } => {
self.tun_device.set_ips(ipv4, ipv6).await?;
self.dns_controller.set_dns(&dns).await?;
self.ipc_tx
.send(serde_json::to_string(&IpcServerMsg::OnTunnelReady)?.into())
.await?;
}
InternalServerMsg::OnUpdateRoutes { ipv4, ipv6 } => {
self.tun_device.set_routes(ipv4, ipv6).await?
}
}
Ok(())
}
fn handle_ipc_msg(&mut self, msg: IpcClientMsg) -> Result<()> {
match msg {
IpcClientMsg::Connect { api_url, token } => {
let token = secrecy::SecretString::from(token);
assert!(connlib.is_none());
assert!(self.connlib.is_none());
let device_id =
device_id::get_or_create().context("Failed to get / create device ID")?;
let (private_key, public_key) = keypair();
@@ -502,31 +574,39 @@ async fn handle_ipc_client(stream: IpcStream) -> Result<()> {
public_key.to_bytes(),
)?;
self.last_connlib_start_instant = Some(Instant::now());
let new_session = connlib_client_shared::Session::connect(
login,
Sockets::new(),
private_key,
None,
callback_handler.clone(),
Some(std::time::Duration::from_secs(60 * 60 * 24 * 30)),
self.callback_handler.clone(),
Some(Duration::from_secs(60 * 60 * 24 * 30)),
tokio::runtime::Handle::try_current()?,
);
new_session.set_dns(dns_control::system_resolvers().unwrap_or_default());
connlib = Some(new_session);
self.connlib = Some(new_session);
}
IpcClientMsg::Disconnect => {
if let Some(connlib) = connlib.take() {
if let Some(connlib) = self.connlib.take() {
connlib.disconnect();
} else {
tracing::error!("Error - Got Disconnect when we're already not connected");
}
}
IpcClientMsg::Reconnect => connlib.as_mut().context("No connlib session")?.reconnect(),
IpcClientMsg::SetDns(v) => connlib.as_mut().context("No connlib session")?.set_dns(v),
IpcClientMsg::Reconnect => self
.connlib
.as_mut()
.context("No connlib session")?
.reconnect(),
IpcClientMsg::SetDns(v) => self
.connlib
.as_mut()
.context("No connlib session")?
.set_dns(v),
}
Ok(())
}
send_task.abort();
Ok(())
}
#[allow(dead_code)]