From 6ab7e51264642f107cfe276d57c0652ac49da178 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 14 Mar 2024 07:09:48 +1100 Subject: [PATCH] refactor(connlib): allow commands to be sent to eventloop (#4112) This refactors `Session` to allow for commands to be sent to the `Eventloop`. Currently, we only send a `Stop` command. With #3429, we will add more commands like refreshing and updating the DNS servers. --- rust/connlib/clients/apple/src/lib.rs | 4 +- rust/connlib/clients/shared/src/eventloop.rs | 20 ++- rust/connlib/clients/shared/src/lib.rs | 135 ++++++++----------- rust/connlib/shared/src/error.rs | 11 +- rust/gui-client/src-tauri/src/client/gui.rs | 2 +- rust/linux-client/src/main.rs | 7 +- 6 files changed, 82 insertions(+), 97 deletions(-) diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index 17fac7b07..3561674e0 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -43,7 +43,7 @@ mod ffi { callback_handler: CallbackHandler, ) -> Result; - fn disconnect(&mut self); + fn disconnect(self); } extern "Swift" { @@ -217,7 +217,7 @@ impl WrappedSession { Ok(Self(session)) } - fn disconnect(&mut self) { + fn disconnect(self) { self.0.disconnect() } } diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 3bd54780d..21644993d 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -14,7 +14,6 @@ use firezone_tunnel::ClientTunnel; use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; use std::{ collections::HashMap, - convert::Infallible, io, path::PathBuf, task::{Context, Poll}, @@ -28,14 +27,22 @@ pub struct Eventloop { tunnel_init: bool, portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + rx: tokio::sync::mpsc::Receiver, + connection_intents: SentConnectionIntents, log_upload_interval: tokio::time::Interval, } +/// Commands that can be sent to the [`Eventloop`]. +pub enum Command { + Stop, +} + impl Eventloop { pub(crate) fn new( tunnel: ClientTunnel, portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + rx: tokio::sync::mpsc::Receiver, ) -> Self { Self { tunnel, @@ -43,6 +50,7 @@ impl Eventloop { tunnel_init: false, connection_intents: SentConnectionIntents::default(), log_upload_interval: upload_interval(), + rx, } } } @@ -52,11 +60,13 @@ where C: Callbacks + 'static, { #[tracing::instrument(name = "Eventloop::poll", skip_all, level = "debug")] - pub fn poll( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { loop { + match self.rx.poll_recv(cx) { + Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Pending => {} + } + match self.tunnel.poll_next_event(cx) { Poll::Ready(Ok(event)) => { self.handle_tunnel_event(event); diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index b01f07546..c159d5de7 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -15,10 +15,10 @@ mod messages; const PHOENIX_TOPIC: &str = "client"; -struct StopRuntime; - +use eventloop::Command; pub use eventloop::Eventloop; use secrecy::Secret; +use tokio::task::JoinHandle; /// Max interval to retry connections to the portal if it's down or the client has network /// connectivity changes. Set this to something short so that the end-user experiences @@ -29,7 +29,8 @@ const MAX_RECONNECT_INTERVAL: Duration = Duration::from_secs(5); /// /// A session is created using [Session::connect], then to stop a session we use [Session::disconnect]. pub struct Session { - runtime_stopper: tokio::sync::mpsc::Sender, + channel: tokio::sync::mpsc::Sender, + _runtime: tokio::runtime::Runtime, } impl Session { @@ -60,7 +61,7 @@ impl Session { // but then platforms should know that this function is blocking. let callbacks = CallbackErrorFacade(callbacks); - let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let (tx, rx) = tokio::sync::mpsc::channel(1); // In android we get an stack-overflow due to tokio // taking too much of the stack-space: @@ -69,81 +70,28 @@ impl Session { .thread_stack_size(3 * 1024 * 1024) .enable_all() .build()?; - { - let callbacks = callbacks.clone(); - let default_panic_hook = std::panic::take_hook(); - std::panic::set_hook(Box::new({ - let tx = tx.clone(); - move |info| { - let tx = tx.clone(); - let err = info - .payload() - .downcast_ref::<&str>() - .map(|s| Error::Panic(s.to_string())) - .unwrap_or(Error::PanicNonStringPayload( - info.location().map(ToString::to_string), - )); - Self::disconnect_inner(tx, &callbacks, Some(err)); - default_panic_hook(info); - } - })); - } - runtime.spawn(connect( + let connect_handle = runtime.spawn(connect( url, private_key, os_version_override, - callbacks, + callbacks.clone(), max_partition_time, + rx, )); - - std::thread::spawn(move || { - rx.blocking_recv(); - runtime.shutdown_background(); - }); + runtime.spawn(connect_supervisor(connect_handle, callbacks)); Ok(Self { - runtime_stopper: tx, + channel: tx, + _runtime: runtime, }) } - fn disconnect_inner( - runtime_stopper: tokio::sync::mpsc::Sender, - callbacks: &CallbackErrorFacade, - error: Option, - ) { - // 1. Close the websocket connection - // 2. Free the device handle (Linux) - // 3. Close the file descriptor (Linux/Android) - // 4. Remove the mapping - - // The way we cleanup the tasks is we drop the runtime - // this means we don't need to keep track of different tasks - // but if any of the tasks never yields this will block forever! - // So always yield and if you spawn a blocking tasks rewrite this. - // Furthermore, we will depend on Drop impls to do the list above so, - // implement them :) - // if there's no receiver the runtime is already stopped - // there's an edge case where this is called before the thread is listening for stop threads. - // but I believe in that case the channel will be in a signaled state achieving the same result - - if let Err(err) = runtime_stopper.try_send(StopRuntime) { - tracing::error!("Couldn't stop runtime: {err}"); - } - - if let Some(error) = error { - let _ = callbacks.on_disconnect(&error); - } - } - - /// Cleanup a [Session]. + /// Disconnect a [`Session`]. /// - /// For now this just drops the runtime, which should drop all pending tasks. - /// Further cleanup should be done here. (Otherwise we can just drop [Session]). - pub fn disconnect(&mut self) { - if let Err(err) = self.runtime_stopper.try_send(StopRuntime) { - tracing::error!("Couldn't stop runtime: {err}"); - } + /// This consumes [`Session`] which cleans up all state associated with it. + pub fn disconnect(self) { + let _ = self.channel.try_send(Command::Stop); } } @@ -156,17 +104,12 @@ async fn connect( os_version_override: Option, callbacks: CB, max_partition_time: Option, -) where + rx: tokio::sync::mpsc::Receiver, +) -> Result<(), Error> +where CB: Callbacks + 'static, { - let tunnel = match Tunnel::new(private_key, callbacks.clone()) { - Ok(tunnel) => tunnel, - Err(e) => { - tracing::error!("Failed to make tunnel: {e}"); - let _ = callbacks.on_disconnect(&e); - return; - } - }; + let tunnel = Tunnel::new(private_key, callbacks.clone())?; let portal = PhoenixChannel::connect( Secret::new(url), @@ -179,13 +122,41 @@ async fn connect( .build(), ); - let mut eventloop = Eventloop::new(tunnel, portal); + let mut eventloop = Eventloop::new(tunnel, portal, rx); - match std::future::poll_fn(|cx| eventloop.poll(cx)).await { - Ok(never) => match never {}, - Err(e) => { - tracing::error!("Eventloop failed: {e}"); - let _ = callbacks.on_disconnect(&Error::PortalConnectionFailed); // TMP Error until we have a narrower API for `onDisconnect` + std::future::poll_fn(|cx| eventloop.poll(cx)) + .await + .map_err(Error::PortalConnectionFailed)?; + + Ok(()) +} + +/// A supervisor task that handles, when [`connect`] exits. +async fn connect_supervisor(connect_handle: JoinHandle>, callbacks: CB) +where + CB: Callbacks, +{ + match connect_handle.await { + Ok(Ok(())) => { + tracing::info!("connlib exited gracefully"); } + Ok(Err(e)) => { + tracing::error!("connlib failed: {e}"); + let _ = callbacks.on_disconnect(&e); + } + Err(e) => match e.try_into_panic() { + Ok(panic) => { + if let Some(msg) = panic.downcast_ref::<&str>() { + let _ = callbacks.on_disconnect(&Error::Panic(msg.to_string())); + return; + } + + let _ = callbacks.on_disconnect(&Error::PanicNonStringPayload); + } + Err(_) => { + tracing::error!("connlib task was cancelled"); + let _ = callbacks.on_disconnect(&Error::Cancelled); + } + }, } } diff --git a/rust/connlib/shared/src/error.rs b/rust/connlib/shared/src/error.rs index 750af76fa..08961588e 100644 --- a/rust/connlib/shared/src/error.rs +++ b/rust/connlib/shared/src/error.rs @@ -100,11 +100,14 @@ pub enum ConnlibError { #[error("No MTU found")] NoMtu, /// A panic occurred. - #[error("Panicked: {0}")] + #[error("Connlib panicked: {0}")] Panic(String), + /// The task was cancelled + #[error("Connlib task was cancelled")] + Cancelled, /// A panic occurred with a non-string payload. #[error("Panicked with a non-string payload")] - PanicNonStringPayload(Option), + PanicNonStringPayload, /// Received connection details that might be stale #[error("Unexpected connection details")] UnexpectedConnectionDetails, @@ -176,8 +179,8 @@ pub enum ConnlibError { #[error("Failed to control system DNS with `resolvectl`")] ResolvectlFailed, - #[error("connection to the portal failed")] - PortalConnectionFailed, + #[error("connection to the portal failed: {0}")] + PortalConnectionFailed(phoenix_channel::Error), } impl ConnlibError { diff --git a/rust/gui-client/src-tauri/src/client/gui.rs b/rust/gui-client/src-tauri/src/client/gui.rs index 742d05533..80451fad8 100644 --- a/rust/gui-client/src-tauri/src/client/gui.rs +++ b/rust/gui-client/src-tauri/src/client/gui.rs @@ -745,7 +745,7 @@ impl Controller { fn sign_out(&mut self) -> Result<()> { self.auth.sign_out()?; self.tunnel_ready = false; - if let Some(mut session) = self.session.take() { + if let Some(session) = self.session.take() { tracing::debug!("disconnecting connlib"); // This is redundant if the token is expired, in that case // connlib already disconnected itself. diff --git a/rust/linux-client/src/main.rs b/rust/linux-client/src/main.rs index 0f84f5842..1410bc820 100644 --- a/rust/linux-client/src/main.rs +++ b/rust/linux-client/src/main.rs @@ -38,7 +38,7 @@ fn main() -> Result<()> { public_key.to_bytes(), )?; - let mut session = + let session = Session::connect(login, private_key, None, callbacks, max_partition_time).unwrap(); block_on_ctrl_c(); @@ -83,8 +83,9 @@ impl Callbacks for CallbackHandler { } fn on_disconnect(&self, error: &connlib_client_shared::Error) -> Result<(), Self::Error> { - tracing::error!(?error, "Disconnected"); - Ok(()) + tracing::error!("Disconnected: {error}"); + + std::process::exit(1); } fn roll_log_file(&self) -> Option {