mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(headless-client): rewrite the IPC service main loop so we can time the Client startup (#5376)
Part of a yak shave to profile startup time for reducing it on Windows #5026 Median of 3 runs: - Windows 11 aarch64 Parallels VM - 4.8 s - Windows 11 x86_64 laptop - 3.1 s (I thought it used to be slower) - Windows Server 2022 VM - 22.2 s --------- Signed-off-by: Reactor Scram <ReactorScram@users.noreply.github.com> Co-authored-by: Jamil <jamilbk@users.noreply.github.com> Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
@@ -29,6 +29,16 @@ pub(crate) struct DnsController {
|
||||
dns_control_method: Option<DnsControlMethod>,
|
||||
}
|
||||
|
||||
impl Default for DnsController {
|
||||
fn default() -> Self {
|
||||
// We'll remove `get_dns_control_from_env` in #5068
|
||||
let dns_control_method = get_dns_control_from_env();
|
||||
tracing::info!(?dns_control_method);
|
||||
|
||||
Self { dns_control_method }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DnsController {
|
||||
fn drop(&mut self) {
|
||||
tracing::debug!("Reverting DNS control...");
|
||||
@@ -40,14 +50,6 @@ impl Drop for DnsController {
|
||||
}
|
||||
|
||||
impl DnsController {
|
||||
pub(crate) fn new() -> Self {
|
||||
// We'll remove `get_dns_control_from_env` in #5068
|
||||
let dns_control_method = get_dns_control_from_env();
|
||||
tracing::info!(?dns_control_method);
|
||||
|
||||
Self { dns_control_method }
|
||||
}
|
||||
|
||||
/// Set the computer's system-wide DNS servers
|
||||
///
|
||||
/// The `mut` in `&mut self` is not needed by Rust's rules, but
|
||||
|
||||
@@ -21,6 +21,7 @@ pub fn system_resolvers_for_gui() -> Result<Vec<IpAddr>> {
|
||||
system_resolvers()
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct DnsController {}
|
||||
|
||||
// Unique magic number that we can use to delete our well-known NRPT rule.
|
||||
@@ -36,10 +37,6 @@ impl Drop for DnsController {
|
||||
}
|
||||
|
||||
impl DnsController {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
/// Set the computer's system-wide DNS servers
|
||||
///
|
||||
/// There's a gap in this because on Windows we deactivate and re-activate control.
|
||||
|
||||
@@ -21,11 +21,16 @@ use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr},
|
||||
path::{Path, PathBuf},
|
||||
pin::pin,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
io::{ReadHalf, WriteHalf},
|
||||
sync::mpsc,
|
||||
time::Instant,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
|
||||
use tracing::subscriber::set_global_default;
|
||||
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Layer as _, Registry};
|
||||
use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Layer, Registry};
|
||||
use url::Url;
|
||||
|
||||
use platform::default_token_path;
|
||||
@@ -51,6 +56,7 @@ pub mod windows;
|
||||
#[cfg(target_os = "windows")]
|
||||
pub(crate) use windows as platform;
|
||||
|
||||
use dns_control::DnsController;
|
||||
use ipc::{Server as IpcServer, Stream as IpcStream};
|
||||
|
||||
/// Only used on Linux
|
||||
@@ -298,7 +304,7 @@ pub fn run_only_headless_client() -> Result<()> {
|
||||
platform::notify_service_controller()?;
|
||||
|
||||
let result = rt.block_on(async {
|
||||
let mut dns_controller = dns_control::DnsController::new();
|
||||
let mut dns_controller = dns_control::DnsController::default();
|
||||
let mut tun_device = tun_device_manager::TunDeviceManager::new()?;
|
||||
let mut signals = Signals::new()?;
|
||||
|
||||
@@ -447,49 +453,115 @@ async fn ipc_listen() -> Result<std::convert::Infallible> {
|
||||
.next_client()
|
||||
.await
|
||||
.context("Failed to wait for incoming IPC connection from a GUI")?;
|
||||
if let Err(error) = handle_ipc_client(stream).await {
|
||||
tracing::error!(?error, "Error while handling IPC client");
|
||||
}
|
||||
Handler::new(stream)?
|
||||
.run()
|
||||
.await
|
||||
.context("Error while handling IPC client")?;
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_ipc_client(stream: IpcStream) -> Result<()> {
|
||||
let (rx, tx) = tokio::io::split(stream);
|
||||
let mut rx = FramedRead::new(rx, LengthDelimitedCodec::new());
|
||||
let mut tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
|
||||
let (cb_tx, mut cb_rx) = mpsc::channel(10);
|
||||
/// Handles one IPC client
|
||||
struct Handler {
|
||||
callback_handler: CallbackHandler,
|
||||
cb_rx: mpsc::Receiver<InternalServerMsg>,
|
||||
connlib: Option<connlib_client_shared::Session>,
|
||||
dns_controller: DnsController,
|
||||
ipc_rx: FramedRead<ReadHalf<IpcStream>, LengthDelimitedCodec>,
|
||||
ipc_tx: FramedWrite<WriteHalf<IpcStream>, LengthDelimitedCodec>,
|
||||
last_connlib_start_instant: Option<Instant>,
|
||||
tun_device: tun_device_manager::TunDeviceManager,
|
||||
}
|
||||
|
||||
let send_task = tokio::spawn(async move {
|
||||
let mut dns_controller = dns_control::DnsController::new();
|
||||
let mut tun_device = tun_device_manager::TunDeviceManager::new()?;
|
||||
enum Event {
|
||||
Callback(InternalServerMsg),
|
||||
Ipc(IpcClientMsg),
|
||||
}
|
||||
|
||||
while let Some(msg) = cb_rx.recv().await {
|
||||
match msg {
|
||||
InternalServerMsg::Ipc(msg) => tx.send(serde_json::to_string(&msg)?.into()).await?,
|
||||
InternalServerMsg::OnSetInterfaceConfig { ipv4, ipv6, dns } => {
|
||||
tun_device.set_ips(ipv4, ipv6).await?;
|
||||
dns_controller.set_dns(&dns).await?;
|
||||
tx.send(serde_json::to_string(&IpcServerMsg::OnTunnelReady)?.into())
|
||||
.await?;
|
||||
}
|
||||
InternalServerMsg::OnUpdateRoutes { ipv4, ipv6 } => {
|
||||
tun_device.set_routes(ipv4, ipv6).await?
|
||||
impl Handler {
|
||||
fn new(stream: IpcStream) -> Result<Self> {
|
||||
let (rx, tx) = tokio::io::split(stream);
|
||||
let ipc_rx = FramedRead::new(rx, LengthDelimitedCodec::new());
|
||||
let ipc_tx = FramedWrite::new(tx, LengthDelimitedCodec::new());
|
||||
let (cb_tx, cb_rx) = mpsc::channel(10);
|
||||
let tun_device = tun_device_manager::TunDeviceManager::new()?;
|
||||
|
||||
Ok(Self {
|
||||
callback_handler: CallbackHandler { cb_tx },
|
||||
cb_rx,
|
||||
connlib: None,
|
||||
dns_controller: Default::default(),
|
||||
ipc_rx,
|
||||
ipc_tx,
|
||||
last_connlib_start_instant: None,
|
||||
tun_device,
|
||||
})
|
||||
}
|
||||
|
||||
async fn run(&mut self) -> Result<()> {
|
||||
loop {
|
||||
let event = {
|
||||
// This borrows `self` so we must drop it before handling the `Event`.
|
||||
let cb = pin!(self.cb_rx.recv());
|
||||
match future::select(self.ipc_rx.next(), cb).await {
|
||||
future::Either::Left((Some(Ok(x)), _)) => Event::Ipc(
|
||||
serde_json::from_slice(&x)
|
||||
.context("Error while deserializing IPC message")?,
|
||||
), // TODO: Integrate the serde_json stuff into a custom Tokio codec
|
||||
future::Either::Left((Some(Err(error)), _)) => Err(error)?,
|
||||
future::Either::Left((None, _)) => {
|
||||
tracing::info!("IPC client disconnected");
|
||||
break;
|
||||
}
|
||||
future::Either::Right((Some(x), _)) => Event::Callback(x),
|
||||
future::Either::Right((None, _)) => {
|
||||
tracing::error!("Impossible - Callback channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
match event {
|
||||
Event::Callback(x) => self.handle_connlib_cb(x).await?,
|
||||
Event::Ipc(msg) => self
|
||||
.handle_ipc_msg(msg)
|
||||
.context("Error while handling IPC message from client")?,
|
||||
}
|
||||
}
|
||||
Ok::<_, anyhow::Error>(())
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let mut connlib = None;
|
||||
let callback_handler = CallbackHandler { cb_tx };
|
||||
while let Some(msg) = rx.next().await {
|
||||
let msg = msg?;
|
||||
let msg: IpcClientMsg = serde_json::from_slice(&msg)?;
|
||||
async fn handle_connlib_cb(&mut self, msg: InternalServerMsg) -> Result<()> {
|
||||
match msg {
|
||||
InternalServerMsg::Ipc(msg) => {
|
||||
// The first `OnUpdateResources` marks when connlib is fully initialized
|
||||
if let IpcServerMsg::OnUpdateResources(_) = &msg {
|
||||
if let Some(instant) = self.last_connlib_start_instant.take() {
|
||||
let dur = instant.elapsed();
|
||||
tracing::info!(?dur, "Connlib started");
|
||||
}
|
||||
}
|
||||
self.ipc_tx
|
||||
.send(serde_json::to_string(&msg)?.into())
|
||||
.await?
|
||||
}
|
||||
InternalServerMsg::OnSetInterfaceConfig { ipv4, ipv6, dns } => {
|
||||
self.tun_device.set_ips(ipv4, ipv6).await?;
|
||||
self.dns_controller.set_dns(&dns).await?;
|
||||
self.ipc_tx
|
||||
.send(serde_json::to_string(&IpcServerMsg::OnTunnelReady)?.into())
|
||||
.await?;
|
||||
}
|
||||
InternalServerMsg::OnUpdateRoutes { ipv4, ipv6 } => {
|
||||
self.tun_device.set_routes(ipv4, ipv6).await?
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_ipc_msg(&mut self, msg: IpcClientMsg) -> Result<()> {
|
||||
match msg {
|
||||
IpcClientMsg::Connect { api_url, token } => {
|
||||
let token = secrecy::SecretString::from(token);
|
||||
assert!(connlib.is_none());
|
||||
assert!(self.connlib.is_none());
|
||||
let device_id =
|
||||
device_id::get_or_create().context("Failed to get / create device ID")?;
|
||||
let (private_key, public_key) = keypair();
|
||||
@@ -502,31 +574,39 @@ async fn handle_ipc_client(stream: IpcStream) -> Result<()> {
|
||||
public_key.to_bytes(),
|
||||
)?;
|
||||
|
||||
self.last_connlib_start_instant = Some(Instant::now());
|
||||
let new_session = connlib_client_shared::Session::connect(
|
||||
login,
|
||||
Sockets::new(),
|
||||
private_key,
|
||||
None,
|
||||
callback_handler.clone(),
|
||||
Some(std::time::Duration::from_secs(60 * 60 * 24 * 30)),
|
||||
self.callback_handler.clone(),
|
||||
Some(Duration::from_secs(60 * 60 * 24 * 30)),
|
||||
tokio::runtime::Handle::try_current()?,
|
||||
);
|
||||
new_session.set_dns(dns_control::system_resolvers().unwrap_or_default());
|
||||
connlib = Some(new_session);
|
||||
self.connlib = Some(new_session);
|
||||
}
|
||||
IpcClientMsg::Disconnect => {
|
||||
if let Some(connlib) = connlib.take() {
|
||||
if let Some(connlib) = self.connlib.take() {
|
||||
connlib.disconnect();
|
||||
} else {
|
||||
tracing::error!("Error - Got Disconnect when we're already not connected");
|
||||
}
|
||||
}
|
||||
IpcClientMsg::Reconnect => connlib.as_mut().context("No connlib session")?.reconnect(),
|
||||
IpcClientMsg::SetDns(v) => connlib.as_mut().context("No connlib session")?.set_dns(v),
|
||||
IpcClientMsg::Reconnect => self
|
||||
.connlib
|
||||
.as_mut()
|
||||
.context("No connlib session")?
|
||||
.reconnect(),
|
||||
IpcClientMsg::SetDns(v) => self
|
||||
.connlib
|
||||
.as_mut()
|
||||
.context("No connlib session")?
|
||||
.set_dns(v),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
send_task.abort();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
|
||||
Reference in New Issue
Block a user