refactor(rust/gui-client): close callbacks when closing connlib (#6590)

Closes #6576

This recreates the callback channel on every connect / disconnect cycle,
to prevent this sequence:

1. Start connlib
2. Fail in `make_tun`
3. Spend several seconds doing platform-specific things
4. Stop connlib (since `make_tun` failed)
5. Come back to the main loop to find a bunch of queued-up callbacks
even though connlib is supposed to be stopped.


Instead we get:

5\. Come back to the main loop and we've dropped the callback receiver,
so any callbacks that connlib sent while we were busy are either dropped
or not even sent.
This commit is contained in:
Reactor Scram
2024-09-04 14:20:45 -05:00
committed by GitHub
parent 700b056cd2
commit 9bc60dc618

View File

@@ -4,7 +4,7 @@ use crate::{
};
use anyhow::{bail, Context as _, Result};
use clap::Parser;
use connlib_client_shared::{keypair, ConnectArgs, LoginUrl, Session};
use connlib_client_shared::{keypair, ConnectArgs, LoginUrl};
use connlib_shared::callbacks::ResourceDescription;
use firezone_bin_shared::{
platform::{tcp_socket_factory, udp_socket_factory, DnsControlMethod},
@@ -236,17 +236,20 @@ async fn ipc_listen(
/// Handles one IPC client
struct Handler<'a> {
callback_handler: CallbackHandler,
cb_rx: mpsc::Receiver<ConnlibMsg>,
connlib: Option<connlib_client_shared::Session>,
dns_controller: &'a mut DnsController,
ipc_rx: ipc::ServerRead,
ipc_tx: ipc::ServerWrite,
last_connlib_start_instant: Option<Instant>,
log_filter_reloader: &'a LogFilterReloader,
session: Option<Session>,
tun_device: TunDeviceManager,
}
struct Session {
cb_rx: mpsc::Receiver<ConnlibMsg>,
connlib: connlib_client_shared::Session,
}
enum Event {
Callback(ConnlibMsg),
CallbackChannelClosed,
@@ -275,18 +278,15 @@ impl<'a> Handler<'a> {
.next_client_split()
.await
.context("Failed to wait for incoming IPC connection from a GUI")?;
let (cb_tx, cb_rx) = mpsc::channel(1_000);
let tun_device = TunDeviceManager::new(DEFAULT_MTU)?;
Ok(Self {
callback_handler: CallbackHandler { cb_tx },
cb_rx,
connlib: None,
dns_controller,
ipc_rx,
ipc_tx,
last_connlib_start_instant: None,
log_filter_reloader,
session: None,
tun_device,
})
}
@@ -353,18 +353,20 @@ impl<'a> Handler<'a> {
}
// `FramedRead::next` is cancel-safe.
if let Poll::Ready(result) = pin!(&mut self.ipc_rx).poll_next(cx) {
return match result {
Some(Ok(x)) => Poll::Ready(Event::Ipc(x)),
Some(Err(error)) => Poll::Ready(Event::IpcError(error)),
None => Poll::Ready(Event::IpcDisconnected),
};
return Poll::Ready(match result {
Some(Ok(x)) => Event::Ipc(x),
Some(Err(error)) => Event::IpcError(error),
None => Event::IpcDisconnected,
});
}
// `tokio::sync::mpsc::Receiver::recv` is cancel-safe.
if let Poll::Ready(option) = self.cb_rx.poll_recv(cx) {
return match option {
Some(x) => Poll::Ready(Event::Callback(x)),
None => Poll::Ready(Event::CallbackChannelClosed),
};
if let Some(session) = self.session.as_mut() {
// `tokio::sync::mpsc::Receiver::recv` is cancel-safe.
if let Poll::Ready(option) = session.cb_rx.poll_recv(cx) {
return Poll::Ready(match option {
Some(x) => Event::Callback(x),
None => Event::CallbackChannelClosed,
});
}
}
Poll::Pending
}
@@ -409,10 +411,6 @@ impl<'a> Handler<'a> {
Ok(())
}
fn tunnel_is_ready(&self) -> bool {
self.last_connlib_start_instant.is_none() && self.connlib.is_some()
}
async fn handle_ipc_msg(&mut self, msg: ClientMsg) -> Result<()> {
match msg {
ClientMsg::ClearLogs => {
@@ -438,36 +436,48 @@ impl<'a> Handler<'a> {
.context("Failed to send `ConnectResult`")?
}
ClientMsg::Disconnect => {
if let Some(connlib) = self.connlib.take() {
connlib.disconnect();
self.dns_controller.deactivate()?;
} else {
let Some(session) = self.session.take() else {
tracing::error!("Error - Got Disconnect when we're already not connected");
}
return Ok(());
};
// Identical to dropping it, but looks nicer.
session.connlib.disconnect();
self.dns_controller.deactivate()?;
}
ClientMsg::ReloadLogFilter => {
let filter = spawn_blocking(get_log_filter).await??;
self.log_filter_reloader.reload(filter)?;
}
ClientMsg::Reset => {
if self.tunnel_is_ready() {
self.connlib.as_mut().context("No connlib session")?.reset();
} else {
tracing::debug!("Ignoring redundant reset");
if self.last_connlib_start_instant.is_some() {
tracing::debug!("Ignoring reset since we're still signing in");
return Ok(());
}
let Some(session) = self.session.as_ref() else {
tracing::debug!("Cannot reset if we're signed out");
return Ok(());
};
session.connlib.reset();
}
ClientMsg::SetDns(resolvers) => {
let Some(session) = self.session.as_ref() else {
tracing::debug!("Cannot set DNS resolvers if we're signed out");
return Ok(());
};
tracing::debug!(?resolvers);
self.connlib
.as_mut()
.context("No connlib session")?
.set_dns(resolvers)
session.connlib.set_dns(resolvers);
}
ClientMsg::SetDisabledResources(disabled_resources) => {
self.connlib
.as_mut()
.context("No connlib session")?
.set_disabled_resources(disabled_resources);
let Some(session) = self.session.as_ref() else {
// At this point, the GUI has already saved the disabled Resources to disk, so it'll be correct on the next sign-in anyway.
tracing::debug!("Cannot set disabled resources if we're signed out");
return Ok(());
};
session.connlib.set_disabled_resources(disabled_resources);
}
}
Ok(())
@@ -479,10 +489,7 @@ impl<'a> Handler<'a> {
///
/// Throws matchable errors for bad URLs, unable to reach the portal, or unable to create the tunnel device
fn connect_to_firezone(&mut self, api_url: &str, token: SecretString) -> Result<(), Error> {
// There isn't an airtight way to implement a "disconnect and reconnect"
// right now because `Session::disconnect` is fire-and-forget:
// <https://github.com/firezone/firezone/blob/663367b6055ced7432866a40a60f9525db13288b/rust/connlib/clients/shared/src/lib.rs#L98-L103>
assert!(self.connlib.is_none());
assert!(self.session.is_none());
let device_id = device_id::get_or_create().map_err(|e| Error::DeviceId(e.to_string()))?;
let (private_key, public_key) = keypair();
@@ -496,11 +503,13 @@ impl<'a> Handler<'a> {
.map_err(|e| Error::LoginUrl(e.to_string()))?;
self.last_connlib_start_instant = Some(Instant::now());
let (cb_tx, cb_rx) = mpsc::channel(1_000);
let callbacks = CallbackHandler { cb_tx };
let args = ConnectArgs {
tcp_socket_factory: Arc::new(tcp_socket_factory),
udp_socket_factory: Arc::new(udp_socket_factory),
private_key,
callbacks: self.callback_handler.clone(),
callbacks,
};
// Synchronous DNS resolution here
@@ -518,16 +527,22 @@ impl<'a> Handler<'a> {
// Read the resolvers before starting connlib, in case connlib's startup interferes.
let dns = self.dns_controller.system_resolvers();
let new_session = Session::connect(args, portal, tokio::runtime::Handle::current());
let connlib = connlib_client_shared::Session::connect(
args,
portal,
tokio::runtime::Handle::current(),
);
// Call `set_dns` before `set_tun` so that the tunnel starts up with a valid list of resolvers.
tracing::debug!(?dns, "Calling `set_dns`...");
new_session.set_dns(dns);
connlib.set_dns(dns);
let tun = self
.tun_device
.make_tun()
.map_err(|e| Error::TunnelDevice(e.to_string()))?;
new_session.set_tun(Box::new(tun));
self.connlib = Some(new_session);
connlib.set_tun(Box::new(tun));
let session = Session { cb_rx, connlib };
self.session = Some(session);
Ok(())
}