refactor(connlib): allow commands to be sent to eventloop (#4112)
This refactors `Session` so that commands can be sent to the `Eventloop`. Currently we only send a `Stop` command; with #3429, we will add more commands, such as refreshing and updating the DNS servers.
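The shape is the usual one for tokio event loops: `Session` keeps the `Sender` half of a bounded `tokio::sync::mpsc` channel, and the `Eventloop` polls the `Receiver` half alongside its other event sources. Below is a minimal, self-contained sketch of that pattern — the `Command` and `Eventloop` names mirror the diff, but everything else is illustrative, not the actual firezone code:

```rust
use std::future::poll_fn;
use std::task::{Context, Poll};

/// Mirrors the `Command` enum introduced in the diff below.
enum Command {
    Stop,
}

struct Eventloop {
    rx: tokio::sync::mpsc::Receiver<Command>,
    // The real loop also owns the tunnel, the portal channel, timers, ...
}

impl Eventloop {
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        // Commands are checked first so a `Stop` is honoured promptly.
        match self.rx.poll_recv(cx) {
            // A closed channel (all senders dropped) also stops the loop.
            Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => Poll::Ready(()),
            // The real loop goes on to poll its other event sources here.
            Poll::Pending => Poll::Pending,
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = tokio::sync::mpsc::channel(1);

    let handle = tokio::spawn(async move {
        let mut eventloop = Eventloop { rx };
        poll_fn(|cx| eventloop.poll(cx)).await;
    });

    tx.send(Command::Stop).await.expect("eventloop is alive");
    handle.await.unwrap();
}
```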
@@ -43,7 +43,7 @@ mod ffi {
             callback_handler: CallbackHandler,
         ) -> Result<WrappedSession, String>;

-        fn disconnect(&mut self);
+        fn disconnect(self);
     }

     extern "Swift" {
@@ -217,7 +217,7 @@ impl WrappedSession {
         Ok(Self(session))
     }

-    fn disconnect(&mut self) {
+    fn disconnect(self) {
         self.0.disconnect()
     }
 }
@@ -14,7 +14,6 @@ use firezone_tunnel::ClientTunnel;
 use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
 use std::{
     collections::HashMap,
-    convert::Infallible,
     io,
     path::PathBuf,
     task::{Context, Poll},
@@ -28,14 +27,22 @@ pub struct Eventloop<C: Callbacks> {
     tunnel_init: bool,

     portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
+    rx: tokio::sync::mpsc::Receiver<Command>,
+
     connection_intents: SentConnectionIntents,
     log_upload_interval: tokio::time::Interval,
 }

+/// Commands that can be sent to the [`Eventloop`].
+pub enum Command {
+    Stop,
+}
+
 impl<C: Callbacks> Eventloop<C> {
     pub(crate) fn new(
         tunnel: ClientTunnel<C>,
         portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
+        rx: tokio::sync::mpsc::Receiver<Command>,
     ) -> Self {
         Self {
             tunnel,
@@ -43,6 +50,7 @@ impl<C: Callbacks> Eventloop<C> {
             tunnel_init: false,
             connection_intents: SentConnectionIntents::default(),
             log_upload_interval: upload_interval(),
+            rx,
         }
     }
 }
@@ -52,11 +60,13 @@ where
     C: Callbacks + 'static,
 {
     #[tracing::instrument(name = "Eventloop::poll", skip_all, level = "debug")]
-    pub fn poll(
-        &mut self,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<Infallible, phoenix_channel::Error>> {
+    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), phoenix_channel::Error>> {
         loop {
+            match self.rx.poll_recv(cx) {
+                Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())),
+                Poll::Pending => {}
+            }
+
             match self.tunnel.poll_next_event(cx) {
                 Poll::Ready(Ok(event)) => {
                     self.handle_tunnel_event(event);
@@ -15,10 +15,10 @@ mod messages;

 const PHOENIX_TOPIC: &str = "client";

-struct StopRuntime;
-
+use eventloop::Command;
 pub use eventloop::Eventloop;
 use secrecy::Secret;
+use tokio::task::JoinHandle;

 /// Max interval to retry connections to the portal if it's down or the client has network
 /// connectivity changes. Set this to something short so that the end-user experiences
@@ -29,7 +29,8 @@ const MAX_RECONNECT_INTERVAL: Duration = Duration::from_secs(5);
 ///
 /// A session is created using [Session::connect], then to stop a session we use [Session::disconnect].
 pub struct Session {
-    runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
+    channel: tokio::sync::mpsc::Sender<Command>,
+    _runtime: tokio::runtime::Runtime,
 }

 impl Session {
@@ -60,7 +61,7 @@ impl Session {
         // but then platforms should know that this function is blocking.

         let callbacks = CallbackErrorFacade(callbacks);
-        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
+        let (tx, rx) = tokio::sync::mpsc::channel(1);

         // In android we get an stack-overflow due to tokio
         // taking too much of the stack-space:
@@ -69,81 +70,28 @@ impl Session {
             .thread_stack_size(3 * 1024 * 1024)
             .enable_all()
             .build()?;
-        {
-            let callbacks = callbacks.clone();
-            let default_panic_hook = std::panic::take_hook();
-            std::panic::set_hook(Box::new({
-                let tx = tx.clone();
-                move |info| {
-                    let tx = tx.clone();
-                    let err = info
-                        .payload()
-                        .downcast_ref::<&str>()
-                        .map(|s| Error::Panic(s.to_string()))
-                        .unwrap_or(Error::PanicNonStringPayload(
-                            info.location().map(ToString::to_string),
-                        ));
-                    Self::disconnect_inner(tx, &callbacks, Some(err));
-                    default_panic_hook(info);
-                }
-            }));
-        }

-        runtime.spawn(connect(
+        let connect_handle = runtime.spawn(connect(
             url,
             private_key,
             os_version_override,
-            callbacks,
+            callbacks.clone(),
             max_partition_time,
+            rx,
         ));
-
-        std::thread::spawn(move || {
-            rx.blocking_recv();
-            runtime.shutdown_background();
-        });
+        runtime.spawn(connect_supervisor(connect_handle, callbacks));

         Ok(Self {
-            runtime_stopper: tx,
+            channel: tx,
+            _runtime: runtime,
         })
     }

-    fn disconnect_inner<CB: Callbacks + 'static>(
-        runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
-        callbacks: &CallbackErrorFacade<CB>,
-        error: Option<Error>,
-    ) {
-        // 1. Close the websocket connection
-        // 2. Free the device handle (Linux)
-        // 3. Close the file descriptor (Linux/Android)
-        // 4. Remove the mapping
-
-        // The way we cleanup the tasks is we drop the runtime
-        // this means we don't need to keep track of different tasks
-        // but if any of the tasks never yields this will block forever!
-        // So always yield and if you spawn a blocking tasks rewrite this.
-        // Furthermore, we will depend on Drop impls to do the list above so,
-        // implement them :)
-        // if there's no receiver the runtime is already stopped
-        // there's an edge case where this is called before the thread is listening for stop threads.
-        // but I believe in that case the channel will be in a signaled state achieving the same result
-
-        if let Err(err) = runtime_stopper.try_send(StopRuntime) {
-            tracing::error!("Couldn't stop runtime: {err}");
-        }
-
-        if let Some(error) = error {
-            let _ = callbacks.on_disconnect(&error);
-        }
-    }
-
-    /// Cleanup a [Session].
+    /// Disconnect a [`Session`].
     ///
-    /// For now this just drops the runtime, which should drop all pending tasks.
-    /// Further cleanup should be done here. (Otherwise we can just drop [Session]).
-    pub fn disconnect(&mut self) {
-        if let Err(err) = self.runtime_stopper.try_send(StopRuntime) {
-            tracing::error!("Couldn't stop runtime: {err}");
-        }
+    /// This consumes [`Session`] which cleans up all state associated with it.
+    pub fn disconnect(self) {
+        let _ = self.channel.try_send(Command::Stop);
     }
 }
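Worth noting in the new `disconnect`: it calls `try_send` and discards the result. With a capacity-1 channel, `try_send` either enqueues the `Stop`, or fails because a stop is already buffered or the receiving event loop is gone — all acceptable outcomes here. A standalone illustration of those semantics (not firezone code):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<&str>(1);

    // `try_send` never blocks: it succeeds while there is capacity...
    assert!(tx.try_send("stop").is_ok());
    // ...and fails fast once the buffer is full.
    assert!(tx.try_send("stop again").is_err());

    // One buffered message is all the receiver needs to shut down.
    assert_eq!(rx.recv().await, Some("stop"));

    // After the receiver is dropped, `try_send` fails with a `Closed` error.
    drop(rx);
    assert!(tx.try_send("stop").is_err());
}
```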
@@ -156,17 +104,12 @@ async fn connect<CB>(
     os_version_override: Option<String>,
     callbacks: CB,
     max_partition_time: Option<Duration>,
-) where
+    rx: tokio::sync::mpsc::Receiver<Command>,
+) -> Result<(), Error>
+where
     CB: Callbacks + 'static,
 {
-    let tunnel = match Tunnel::new(private_key, callbacks.clone()) {
-        Ok(tunnel) => tunnel,
-        Err(e) => {
-            tracing::error!("Failed to make tunnel: {e}");
-            let _ = callbacks.on_disconnect(&e);
-            return;
-        }
-    };
+    let tunnel = Tunnel::new(private_key, callbacks.clone())?;

     let portal = PhoenixChannel::connect(
         Secret::new(url),
@@ -179,13 +122,41 @@
         .build(),
     );

-    let mut eventloop = Eventloop::new(tunnel, portal);
+    let mut eventloop = Eventloop::new(tunnel, portal, rx);

-    match std::future::poll_fn(|cx| eventloop.poll(cx)).await {
-        Ok(never) => match never {},
-        Err(e) => {
-            tracing::error!("Eventloop failed: {e}");
-            let _ = callbacks.on_disconnect(&Error::PortalConnectionFailed); // TMP Error until we have a narrower API for `onDisconnect`
-        }
-    }
-}
+    std::future::poll_fn(|cx| eventloop.poll(cx))
+        .await
+        .map_err(Error::PortalConnectionFailed)?;
+
+    Ok(())
+}
+
+/// A supervisor task that handles, when [`connect`] exits.
+async fn connect_supervisor<CB>(connect_handle: JoinHandle<Result<(), Error>>, callbacks: CB)
+where
+    CB: Callbacks,
+{
+    match connect_handle.await {
+        Ok(Ok(())) => {
+            tracing::info!("connlib exited gracefully");
+        }
+        Ok(Err(e)) => {
+            tracing::error!("connlib failed: {e}");
+            let _ = callbacks.on_disconnect(&e);
+        }
+        Err(e) => match e.try_into_panic() {
+            Ok(panic) => {
+                if let Some(msg) = panic.downcast_ref::<&str>() {
+                    let _ = callbacks.on_disconnect(&Error::Panic(msg.to_string()));
+                    return;
+                }
+
+                let _ = callbacks.on_disconnect(&Error::PanicNonStringPayload);
+            }
+            Err(_) => {
+                tracing::error!("connlib task was cancelled");
+                let _ = callbacks.on_disconnect(&Error::Cancelled);
+            }
+        },
+    }
+}
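The supervisor leans on `JoinError::try_into_panic` to tell a panicked `connect` task apart from a cancelled one, which is what lets this commit delete the custom panic hook. A standalone sketch of that mechanism (illustrative only):

```rust
use tokio::task::JoinHandle;

#[tokio::main]
async fn main() {
    let handle: JoinHandle<()> = tokio::spawn(async {
        panic!("boom");
    });

    // A `JoinError` is either a panic or a cancellation; `try_into_panic`
    // recovers the original panic payload in the former case.
    match handle.await {
        Ok(()) => println!("task finished"),
        Err(e) => match e.try_into_panic() {
            Ok(payload) => {
                // String literals panic with a `&str` payload.
                let msg = payload
                    .downcast_ref::<&str>()
                    .copied()
                    .unwrap_or("<non-string panic payload>");
                println!("task panicked: {msg}");
            }
            Err(_) => println!("task was cancelled"),
        },
    }
}
```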
@@ -100,11 +100,14 @@ pub enum ConnlibError {
     #[error("No MTU found")]
     NoMtu,
     /// A panic occurred.
-    #[error("Panicked: {0}")]
+    #[error("Connlib panicked: {0}")]
     Panic(String),
+    /// The task was cancelled
+    #[error("Connlib task was cancelled")]
+    Cancelled,
     /// A panic occurred with a non-string payload.
     #[error("Panicked with a non-string payload")]
-    PanicNonStringPayload(Option<String>),
+    PanicNonStringPayload,
     /// Received connection details that might be stale
     #[error("Unexpected connection details")]
     UnexpectedConnectionDetails,
@@ -176,8 +179,8 @@ pub enum ConnlibError {
     #[error("Failed to control system DNS with `resolvectl`")]
     ResolvectlFailed,

-    #[error("connection to the portal failed")]
-    PortalConnectionFailed,
+    #[error("connection to the portal failed: {0}")]
+    PortalConnectionFailed(phoenix_channel::Error),
 }

 impl ConnlibError {
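`PortalConnectionFailed` now carries the underlying `phoenix_channel::Error`, so the rendered message includes the actual cause. With `thiserror`, the `{0}` in the attribute interpolates the `Display` output of the first field — a quick standalone check, using a `String` stand-in for the real error type:

```rust
use thiserror::Error;

// Stand-in payload; the real variant wraps `phoenix_channel::Error`.
#[derive(Debug, Error)]
#[error("connection to the portal failed: {0}")]
struct PortalConnectionFailed(String);

fn main() {
    let err = PortalConnectionFailed("websocket closed".to_owned());
    assert_eq!(
        err.to_string(),
        "connection to the portal failed: websocket closed"
    );
}
```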
@@ -745,7 +745,7 @@ impl Controller {
     fn sign_out(&mut self) -> Result<()> {
         self.auth.sign_out()?;
         self.tunnel_ready = false;
-        if let Some(mut session) = self.session.take() {
+        if let Some(session) = self.session.take() {
             tracing::debug!("disconnecting connlib");
             // This is redundant if the token is expired, in that case
             // connlib already disconnected itself.
@@ -38,7 +38,7 @@ fn main() -> Result<()> {
         public_key.to_bytes(),
     )?;

-    let mut session =
+    let session =
         Session::connect(login, private_key, None, callbacks, max_partition_time).unwrap();

     block_on_ctrl_c();
@@ -83,8 +83,9 @@ impl Callbacks for CallbackHandler {
     }

     fn on_disconnect(&self, error: &connlib_client_shared::Error) -> Result<(), Self::Error> {
-        tracing::error!(?error, "Disconnected");
-        Ok(())
+        tracing::error!("Disconnected: {error}");
+
+        std::process::exit(1);
     }

     fn roll_log_file(&self) -> Option<PathBuf> {