refactor(connlib): allow commands to be sent to eventloop (#4112)

This refactors `Session` so that commands can be sent to the
`Eventloop`. Currently, we only send a `Stop` command; with #3429, we
will add more commands, such as refreshing and updating the DNS servers.
Thomas Eizinger
2024-03-14 07:09:48 +11:00
committed by GitHub
parent 7fd3d1a6b1
commit 6ab7e51264
6 changed files with 82 additions and 97 deletions
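
In essence, `Session` keeps the `Sender` half of a bounded mpsc channel and the `Eventloop` polls the `Receiver` alongside its other event sources. A minimal sketch of the pattern (names simplified; the real `Eventloop` also polls the tunnel and portal, as the diffs below show):

```rust
use std::task::{Context, Poll};
use tokio::sync::mpsc;

/// Commands a `Session` handle can send to its event loop.
enum Command {
    Stop,
}

struct Eventloop {
    rx: mpsc::Receiver<Command>,
}

impl Eventloop {
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        // Check the command channel first so `Stop` takes priority over other work.
        match self.rx.poll_recv(cx) {
            // A closed channel (all senders dropped) also stops the loop.
            Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(()),
            Poll::Pending => {}
        }
        // ... poll tunnel and portal events here ...
        Poll::Pending
    }
}
```

The handle side then reduces to `channel.try_send(Command::Stop)`: `try_send` never blocks, and a full or closed channel simply means the loop is already stopping.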

View File

@@ -43,7 +43,7 @@ mod ffi {
             callback_handler: CallbackHandler,
         ) -> Result<WrappedSession, String>;
-        fn disconnect(&mut self);
+        fn disconnect(self);
     }
     extern "Swift" {
@@ -217,7 +217,7 @@ impl WrappedSession {
         Ok(Self(session))
     }
-    fn disconnect(&mut self) {
+    fn disconnect(self) {
         self.0.disconnect()
     }
 }

View File

@@ -14,7 +14,6 @@ use firezone_tunnel::ClientTunnel;
 use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
 use std::{
     collections::HashMap,
-    convert::Infallible,
     io,
     path::PathBuf,
     task::{Context, Poll},
@@ -28,14 +27,22 @@ pub struct Eventloop<C: Callbacks> {
     tunnel_init: bool,
     portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
+    rx: tokio::sync::mpsc::Receiver<Command>,
     connection_intents: SentConnectionIntents,
     log_upload_interval: tokio::time::Interval,
 }
+/// Commands that can be sent to the [`Eventloop`].
+pub enum Command {
+    Stop,
+}
 impl<C: Callbacks> Eventloop<C> {
     pub(crate) fn new(
         tunnel: ClientTunnel<C>,
         portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
+        rx: tokio::sync::mpsc::Receiver<Command>,
     ) -> Self {
         Self {
             tunnel,
@@ -43,6 +50,7 @@ impl<C: Callbacks> Eventloop<C> {
             tunnel_init: false,
             connection_intents: SentConnectionIntents::default(),
             log_upload_interval: upload_interval(),
+            rx,
         }
     }
 }
@@ -52,11 +60,13 @@ where
     C: Callbacks + 'static,
 {
     #[tracing::instrument(name = "Eventloop::poll", skip_all, level = "debug")]
-    pub fn poll(
-        &mut self,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<Infallible, phoenix_channel::Error>> {
+    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), phoenix_channel::Error>> {
         loop {
+            match self.rx.poll_recv(cx) {
+                Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())),
+                Poll::Pending => {}
+            }
             match self.tunnel.poll_next_event(cx) {
                 Poll::Ready(Ok(event)) => {
                     self.handle_tunnel_event(event);
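
Returning `Poll<Result<(), _>>` instead of `Poll<Result<Infallible, _>>` turns shutdown into an ordinary `Ok(())` resolution that callers can await. A sketch of driving and stopping the loop (construction of `tunnel` and `portal` elided; `Eventloop::new` is the constructor from the hunk above):

```rust
let (tx, rx) = tokio::sync::mpsc::channel(1);
let mut eventloop = Eventloop::new(tunnel, portal, rx);

let driver = tokio::spawn(async move {
    // Resolves with `Ok(())` on `Command::Stop`, `Err(_)` on a portal failure.
    std::future::poll_fn(|cx| eventloop.poll(cx)).await
});

let _ = tx.send(Command::Stop).await; // or drop `tx`; a closed channel also stops the loop
```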

View File

@@ -15,10 +15,10 @@ mod messages;
 const PHOENIX_TOPIC: &str = "client";
-struct StopRuntime;
+use eventloop::Command;
 pub use eventloop::Eventloop;
 use secrecy::Secret;
+use tokio::task::JoinHandle;
 /// Max interval to retry connections to the portal if it's down or the client has network
 /// connectivity changes. Set this to something short so that the end-user experiences
@@ -29,7 +29,8 @@ const MAX_RECONNECT_INTERVAL: Duration = Duration::from_secs(5);
 ///
 /// A session is created using [Session::connect] and stopped with [Session::disconnect].
 pub struct Session {
-    runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
+    channel: tokio::sync::mpsc::Sender<Command>,
+    _runtime: tokio::runtime::Runtime,
 }
 impl Session {
@@ -60,7 +61,7 @@ impl Session {
     // but then platforms should know that this function is blocking.
     let callbacks = CallbackErrorFacade(callbacks);
-    let (tx, mut rx) = tokio::sync::mpsc::channel(1);
+    let (tx, rx) = tokio::sync::mpsc::channel(1);
     // On Android we get a stack overflow due to tokio
     // taking too much of the stack space:
@@ -69,81 +70,28 @@ impl Session {
             .thread_stack_size(3 * 1024 * 1024)
             .enable_all()
             .build()?;
-        {
-            let callbacks = callbacks.clone();
-            let default_panic_hook = std::panic::take_hook();
-            std::panic::set_hook(Box::new({
-                let tx = tx.clone();
-                move |info| {
-                    let tx = tx.clone();
-                    let err = info
-                        .payload()
-                        .downcast_ref::<&str>()
-                        .map(|s| Error::Panic(s.to_string()))
-                        .unwrap_or(Error::PanicNonStringPayload(
-                            info.location().map(ToString::to_string),
-                        ));
-                    Self::disconnect_inner(tx, &callbacks, Some(err));
-                    default_panic_hook(info);
-                }
-            }));
-        }
-        runtime.spawn(connect(
+        let connect_handle = runtime.spawn(connect(
            url,
            private_key,
            os_version_override,
-            callbacks,
+            callbacks.clone(),
            max_partition_time,
+            rx,
        ));
-        std::thread::spawn(move || {
-            rx.blocking_recv();
-            runtime.shutdown_background();
-        });
+        runtime.spawn(connect_supervisor(connect_handle, callbacks));
         Ok(Self {
-            runtime_stopper: tx,
+            channel: tx,
+            _runtime: runtime,
         })
     }
-    fn disconnect_inner<CB: Callbacks + 'static>(
-        runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
-        callbacks: &CallbackErrorFacade<CB>,
-        error: Option<Error>,
-    ) {
-        // 1. Close the websocket connection
-        // 2. Free the device handle (Linux)
-        // 3. Close the file descriptor (Linux/Android)
-        // 4. Remove the mapping
-        // The way we clean up the tasks is that we drop the runtime;
-        // this means we don't need to keep track of different tasks,
-        // but if any of the tasks never yields, this will block forever!
-        // So always yield, and if you spawn a blocking task, rewrite this.
-        // Furthermore, we depend on Drop impls to do the list above, so
-        // implement them :)
-        // If there's no receiver, the runtime is already stopped.
-        // There's an edge case where this is called before the thread is listening for the stop message,
-        // but I believe in that case the channel will be in a signaled state, achieving the same result.
-        if let Err(err) = runtime_stopper.try_send(StopRuntime) {
-            tracing::error!("Couldn't stop runtime: {err}");
-        }
-        if let Some(error) = error {
-            let _ = callbacks.on_disconnect(&error);
-        }
-    }
-    /// Cleanup a [Session].
+    /// Disconnect a [`Session`].
     ///
-    /// For now this just drops the runtime, which should drop all pending tasks.
-    /// Further cleanup should be done here. (Otherwise we could just drop [Session].)
-    pub fn disconnect(&mut self) {
-        if let Err(err) = self.runtime_stopper.try_send(StopRuntime) {
-            tracing::error!("Couldn't stop runtime: {err}");
-        }
+    /// This consumes [`Session`], which cleans up all state associated with it.
+    pub fn disconnect(self) {
+        let _ = self.channel.try_send(Command::Stop);
     }
 }
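
Note the signature change: `disconnect` now takes `self` by value (mirrored in the Swift bridge above), so a double disconnect becomes a compile-time error rather than a runtime no-op. Illustratively:

```rust
let session = Session::connect(login, private_key, None, callbacks, max_partition_time)?;
// ...
session.disconnect();    // sends `Command::Stop` and consumes the `Session`
// session.disconnect(); // error[E0382]: use of moved value: `session`
```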
@@ -156,17 +104,12 @@ async fn connect<CB>(
     os_version_override: Option<String>,
     callbacks: CB,
     max_partition_time: Option<Duration>,
-) where
+    rx: tokio::sync::mpsc::Receiver<Command>,
+) -> Result<(), Error>
+where
     CB: Callbacks + 'static,
 {
-    let tunnel = match Tunnel::new(private_key, callbacks.clone()) {
-        Ok(tunnel) => tunnel,
-        Err(e) => {
-            tracing::error!("Failed to make tunnel: {e}");
-            let _ = callbacks.on_disconnect(&e);
-            return;
-        }
-    };
+    let tunnel = Tunnel::new(private_key, callbacks.clone())?;
     let portal = PhoenixChannel::connect(
         Secret::new(url),
@@ -179,13 +122,41 @@ async fn connect<CB>(
             .build(),
     );
-    let mut eventloop = Eventloop::new(tunnel, portal);
+    let mut eventloop = Eventloop::new(tunnel, portal, rx);
-    match std::future::poll_fn(|cx| eventloop.poll(cx)).await {
-        Ok(never) => match never {},
-        Err(e) => {
-            tracing::error!("Eventloop failed: {e}");
-            let _ = callbacks.on_disconnect(&Error::PortalConnectionFailed); // TMP Error until we have a narrower API for `onDisconnect`
-        }
-    }
+    std::future::poll_fn(|cx| eventloop.poll(cx))
+        .await
+        .map_err(Error::PortalConnectionFailed)?;
+    Ok(())
 }
+/// A supervisor task that handles when [`connect`] exits.
+async fn connect_supervisor<CB>(connect_handle: JoinHandle<Result<(), Error>>, callbacks: CB)
+where
+    CB: Callbacks,
+{
+    match connect_handle.await {
+        Ok(Ok(())) => {
+            tracing::info!("connlib exited gracefully");
+        }
+        Ok(Err(e)) => {
+            tracing::error!("connlib failed: {e}");
+            let _ = callbacks.on_disconnect(&e);
+        }
+        Err(e) => match e.try_into_panic() {
+            Ok(panic) => {
+                if let Some(msg) = panic.downcast_ref::<&str>() {
+                    let _ = callbacks.on_disconnect(&Error::Panic(msg.to_string()));
+                    return;
+                }
+                let _ = callbacks.on_disconnect(&Error::PanicNonStringPayload);
+            }
+            Err(_) => {
+                tracing::error!("connlib task was cancelled");
+                let _ = callbacks.on_disconnect(&Error::Cancelled);
+            }
+        },
+    }
+}
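
One nuance in the supervisor: `JoinError::try_into_panic` yields the raw payload as `Box<dyn Any + Send>`. Panics from string literals carry a `&'static str`, but formatted panics (e.g. `panic!("{e}")`) carry a `String`, which the `&str` downcast above misses, so they surface as `PanicNonStringPayload`. A hypothetical helper that recovers both:

```rust
use std::any::Any;

// Hypothetical: recover a message from either common panic payload type.
fn panic_message(payload: Box<dyn Any + Send>) -> Option<String> {
    payload
        .downcast_ref::<&str>()
        .map(|s| s.to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
}
```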

View File

@@ -100,11 +100,14 @@ pub enum ConnlibError {
#[error("No MTU found")]
NoMtu,
/// A panic occurred.
#[error("Panicked: {0}")]
#[error("Connlib panicked: {0}")]
Panic(String),
/// The task was cancelled
#[error("Connlib task was cancelled")]
Cancelled,
/// A panic occurred with a non-string payload.
#[error("Panicked with a non-string payload")]
PanicNonStringPayload(Option<String>),
PanicNonStringPayload,
/// Received connection details that might be stale
#[error("Unexpected connection details")]
UnexpectedConnectionDetails,
@@ -176,8 +179,8 @@ pub enum ConnlibError {
#[error("Failed to control system DNS with `resolvectl`")]
ResolvectlFailed,
#[error("connection to the portal failed")]
PortalConnectionFailed,
#[error("connection to the portal failed: {0}")]
PortalConnectionFailed(phoenix_channel::Error),
}
impl ConnlibError {
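
The `#[error(...)]` attributes come from a `thiserror` derive; switching `PortalConnectionFailed` from a unit to a tuple variant lets `{0}` surface the underlying cause in the display message. A condensed sketch (with a stand-in for `phoenix_channel::Error` to keep it self-contained):

```rust
use thiserror::Error;

// Stand-in for `phoenix_channel::Error`, which implements `Display`.
#[derive(Debug, Error)]
#[error("websocket closed")]
struct PhoenixError;

#[derive(Debug, Error)]
enum ConnlibError {
    // The tuple payload is rendered by `{0}`:
    // "connection to the portal failed: websocket closed"
    #[error("connection to the portal failed: {0}")]
    PortalConnectionFailed(PhoenixError),
}
```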

View File

@@ -745,7 +745,7 @@ impl Controller {
     fn sign_out(&mut self) -> Result<()> {
         self.auth.sign_out()?;
         self.tunnel_ready = false;
-        if let Some(mut session) = self.session.take() {
+        if let Some(session) = self.session.take() {
             tracing::debug!("disconnecting connlib");
             // This is redundant if the token is expired; in that case
             // connlib already disconnected itself.

View File

@@ -38,7 +38,7 @@ fn main() -> Result<()> {
         public_key.to_bytes(),
     )?;
-    let mut session =
+    let session =
         Session::connect(login, private_key, None, callbacks, max_partition_time).unwrap();
     block_on_ctrl_c();
@@ -83,8 +83,9 @@ impl Callbacks for CallbackHandler {
     }
     fn on_disconnect(&self, error: &connlib_client_shared::Error) -> Result<(), Self::Error> {
-        tracing::error!(?error, "Disconnected");
-        Ok(())
+        tracing::error!("Disconnected: {error}");
+        std::process::exit(1);
     }
     fn roll_log_file(&self) -> Option<PathBuf> {