From e5e18e78a386f6b1db5bed66b9ef7d5b8293e323 Mon Sep 17 00:00:00 2001 From: Francesca Lovebloom Date: Wed, 19 Jul 2023 15:36:06 -0700 Subject: [PATCH] connlib: Disconnect on fatal error (#1801) Resolves firezone/product#619 This additionally removes `ErrorType`: - `on_error` is now exclusively used for recoverable errors, and no longer has an `error_type` parameter. - `on_disconnect` now has an optional `error` parameter, which specifies the fatal error that caused the disconnect if relevant. --- rust/Cargo.lock | 1 + rust/connlib/clients/android/src/lib.rs | 10 +- .../Sources/Connlib/CallbackHandler.swift | 11 +- rust/connlib/clients/apple/src/lib.rs | 39 +++--- rust/connlib/clients/headless/src/main.rs | 15 +-- rust/connlib/gateway/src/main.rs | 17 +-- rust/connlib/libs/client/src/control.rs | 44 +++---- rust/connlib/libs/client/src/lib.rs | 4 +- rust/connlib/libs/common/Cargo.toml | 1 + rust/connlib/libs/common/src/error_type.rs | 16 --- rust/connlib/libs/common/src/lib.rs | 1 - rust/connlib/libs/common/src/session.rs | 115 +++++++++++------- rust/connlib/libs/gateway/src/control.rs | 27 ++-- rust/connlib/libs/gateway/src/lib.rs | 2 +- .../libs/tunnel/src/control_protocol.rs | 11 +- rust/connlib/libs/tunnel/src/lib.rs | 37 +++--- rust/connlib/libs/tunnel/src/peer.rs | 4 +- 17 files changed, 166 insertions(+), 189 deletions(-) delete mode 100644 rust/connlib/libs/common/src/error_type.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 0554d3868..ae14b7bee 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1654,6 +1654,7 @@ dependencies = [ "futures-util", "ip_network", "os_info", + "parking_lot", "rand", "rand_core 0.6.4", "rtnetlink", diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index f6c4fdf83..b0da22484 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -3,9 +3,7 @@ // However, this consideration has made it idiomatic for Java FFI in the Rust // ecosystem, so it's used here for consistency. -use firezone_client_connlib::{ - Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses, -}; +use firezone_client_connlib::{Callbacks, Error, ResourceList, Session, TunnelAddresses}; use jni::{ objects::{JClass, JObject, JString, JValue}, JNIEnv, @@ -51,11 +49,11 @@ impl Callbacks for CallbackHandler { todo!() } - fn on_disconnect(&self) { + fn on_disconnect(&self, _error: Option<&Error>) { todo!() } - fn on_error(&self, _error: &Error, _error_type: ErrorType) { + fn on_error(&self, _error: &Error) { todo!() } } @@ -108,7 +106,7 @@ pub unsafe extern "system" fn Java_dev_firezone_connlib_Session_disconnect( } let session = unsafe { &mut *session_ptr }; - session.disconnect() + session.disconnect(None) } /// # Safety diff --git a/rust/connlib/clients/apple/Sources/Connlib/CallbackHandler.swift b/rust/connlib/clients/apple/Sources/Connlib/CallbackHandler.swift index f98399379..4ff9afdda 100644 --- a/rust/connlib/clients/apple/Sources/Connlib/CallbackHandler.swift +++ b/rust/connlib/clients/apple/Sources/Connlib/CallbackHandler.swift @@ -51,13 +51,14 @@ public class CallbackHandler { delegate?.onUpdateResources(resourceList: resourceList.resources.toString()) } - func onDisconnect() { - logger.debug("CallbackHandler.onDisconnect") + func onDisconnect(error: SwiftConnlibError) { + logger.debug("CallbackHandler.onDisconnect: \(error, privacy: .public)") + // TODO: convert `error` to `Optional` by checking for `None` case delegate?.onDisconnect() } - func onError(error: SwiftConnlibError, error_type: SwiftErrorType) { - logger.debug("CallbackHandler.onError: \(error, privacy: .public) (\(error_type == .Recoverable ? "Recoverable" : "Fatal", privacy: .public)") - delegate?.onError(error: error, isRecoverable: error_type == .Recoverable) + func onError(error: SwiftConnlibError) { + logger.debug("CallbackHandler.onError: \(error, privacy: .public)") + delegate?.onError(error: error, isRecoverable: true) } } diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index cbfdde784..cec561b0e 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -2,9 +2,7 @@ // Swift bridge generated code triggers this below #![allow(improper_ctypes, non_camel_case_types)] -use firezone_client_connlib::{ - Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses, -}; +use firezone_client_connlib::{Callbacks, Error, ResourceList, Session, TunnelAddresses}; use std::{net::Ipv4Addr, sync::Arc}; #[swift_bridge::bridge] @@ -23,6 +21,9 @@ mod ffi { // TODO: Duplicating these enum variants from `libs/common/src/error.rs` is // brittle/noisy/tedious enum SwiftConnlibError { + // `swift-bridge` doesn't seem to support `Option` for Swift function + // arguments... + None, Io, Base64DecodeError, Base64DecodeSliceError, @@ -46,11 +47,6 @@ mod ffi { NoMtu, } - enum SwiftErrorType { - Recoverable, - Fatal, - } - extern "Rust" { type WrappedSession; @@ -89,10 +85,10 @@ mod ffi { fn on_update_resources(&self, resourceList: ResourceList); #[swift_bridge(swift_name = "onDisconnect")] - fn on_disconnect(&self); + fn on_disconnect(&self, error: SwiftConnlibError); #[swift_bridge(swift_name = "onError")] - fn on_error(&self, error: SwiftConnlibError, error_type: SwiftErrorType); + fn on_error(&self, error: SwiftConnlibError); } } @@ -130,15 +126,6 @@ impl From for ffi::SwiftConnlibError { } } -impl From for ffi::SwiftErrorType { - fn from(val: ErrorType) -> Self { - match val { - ErrorType::Recoverable => Self::Recoverable, - ErrorType::Fatal => Self::Fatal, - } - } -} - impl From for ffi::ResourceList { fn from(value: ResourceList) -> Self { Self { @@ -195,12 +182,16 @@ impl Callbacks for CallbackHandler { self.0.on_update_resources(resource_list.into()) } - fn on_disconnect(&self) { - self.0.on_disconnect() + fn on_disconnect(&self, error: Option<&Error>) { + self.0.on_disconnect( + error + .map(Into::into) + .unwrap_or(ffi::SwiftConnlibError::None), + ) } - fn on_error(&self, error: &Error, error_type: ErrorType) { - self.0.on_error(error.into(), error_type.into()) + fn on_error(&self, error: &Error) { + self.0.on_error(error.into()) } } @@ -230,6 +221,6 @@ impl WrappedSession { } fn disconnect(&mut self) -> bool { - self.session.disconnect() + self.session.disconnect(None) } } diff --git a/rust/connlib/clients/headless/src/main.rs b/rust/connlib/clients/headless/src/main.rs index e794cc990..26bf5321c 100644 --- a/rust/connlib/clients/headless/src/main.rs +++ b/rust/connlib/clients/headless/src/main.rs @@ -3,7 +3,7 @@ use clap::Parser; use std::{net::Ipv4Addr, str::FromStr}; use firezone_client_connlib::{ - get_user_agent, Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses, + get_user_agent, Callbacks, Error, ResourceList, Session, TunnelAddresses, }; use url::Url; @@ -25,15 +25,12 @@ impl Callbacks for CallbackHandler { tracing::trace!("Resources updated, current list: {resource_list:?}"); } - fn on_disconnect(&self) { - tracing::trace!("Tunnel disconnected"); + fn on_disconnect(&self, error: Option<&Error>) { + tracing::trace!("Tunnel disconnected: {error:?}"); } - fn on_error(&self, error: &Error, error_type: ErrorType) { - match error_type { - ErrorType::Recoverable => tracing::warn!("Encountered error: {error}"), - ErrorType::Fatal => panic!("Encountered fatal error: {error}"), - } + fn on_error(&self, error: &Error) { + tracing::warn!("Encountered recoverable error: {error}"); } } @@ -54,7 +51,7 @@ fn main() -> Result<()> { let mut session = Session::connect(url, secret, CallbackHandler).unwrap(); tracing::info!("Started new session"); session.wait_for_ctrl_c().unwrap(); - session.disconnect(); + session.disconnect(None); Ok(()) } diff --git a/rust/connlib/gateway/src/main.rs b/rust/connlib/gateway/src/main.rs index 248f7dc06..a67d4e7c2 100644 --- a/rust/connlib/gateway/src/main.rs +++ b/rust/connlib/gateway/src/main.rs @@ -1,9 +1,7 @@ use anyhow::{Context, Result}; use std::{net::Ipv4Addr, str::FromStr}; -use firezone_gateway_connlib::{ - Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses, -}; +use firezone_gateway_connlib::{Callbacks, Error, ResourceList, Session, TunnelAddresses}; use url::Url; #[derive(Clone)] @@ -24,15 +22,12 @@ impl Callbacks for CallbackHandler { tracing::trace!("Resources updated, current list: {resource_list:?}"); } - fn on_disconnect(&self) { - tracing::trace!("Tunnel disconnected"); + fn on_disconnect(&self, error: Option<&Error>) { + tracing::trace!("Tunnel disconnected: {error:?}"); } - fn on_error(&self, error: &Error, error_type: ErrorType) { - match error_type { - ErrorType::Recoverable => tracing::warn!("Encountered error: {error}"), - ErrorType::Fatal => panic!("Encountered fatal error: {error}"), - } + fn on_error(&self, error: &Error) { + tracing::warn!("Encountered recoverable error: {error}"); } } @@ -46,7 +41,7 @@ fn main() -> Result<()> { let secret = parse_env_var::(SECRET_ENV_VAR)?; let mut session = Session::connect(url, secret, CallbackHandler).unwrap(); session.wait_for_ctrl_c().unwrap(); - session.disconnect(); + session.disconnect(None); Ok(()) } diff --git a/rust/connlib/libs/client/src/control.rs b/rust/connlib/libs/client/src/control.rs index 5462f3ab1..0a081e11b 100644 --- a/rust/connlib/libs/client/src/control.rs +++ b/rust/connlib/libs/client/src/control.rs @@ -4,7 +4,6 @@ use crate::messages::{Connect, EgressMessages, InitClient, Messages, Relays}; use boringtun::x25519::StaticSecret; use libs_common::{ control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic}, - error_type::ErrorType::{self, Fatal, Recoverable}, messages::{Id, ResourceDescription}, Callbacks, ControlSession, Error, Result, }; @@ -45,13 +44,13 @@ struct ControlSignaler { impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] - async fn start(mut self, mut receiver: Receiver>) { + async fn start(mut self, mut receiver: Receiver>) -> Result<()> { let mut interval = tokio::time::interval(Duration::from_secs(10)); loop { tokio::select! { Some(msg) = receiver.recv() => { match msg { - Ok(msg) => self.handle_message(msg).await, + Ok(msg) => self.handle_message(msg).await?, Err(msg_reply) => self.handle_error(msg_reply).await, } }, @@ -59,6 +58,7 @@ impl ControlPlane { else => break } } + Ok(()) } #[tracing::instrument(level = "trace", skip_all)] @@ -68,18 +68,17 @@ impl ControlPlane { interface, resources, }: InitClient, - ) { + ) -> Result<()> { if let Err(e) = self.tunnel.set_interface(&interface).await { tracing::error!("Couldn't initialize interface: {e}"); - self.tunnel.callbacks().on_error(&e, Fatal); - return; + Err(e) + } else { + for resource_description in resources { + self.add_resource(resource_description).await?; + } + tracing::info!("Firezoned Started!"); + Ok(()) } - - for resource_description in resources { - self.add_resource(resource_description).await - } - - tracing::info!("Firezoned Started!"); } #[tracing::instrument(level = "trace", skip(self))] @@ -101,13 +100,13 @@ impl ControlPlane { ) .await { - self.tunnel.callbacks().on_error(&e, Recoverable); + self.tunnel.callbacks().on_error(&e); } } #[tracing::instrument(level = "trace", skip(self))] - async fn add_resource(&self, resource_description: ResourceDescription) { - self.tunnel.add_resource(resource_description).await; + async fn add_resource(&self, resource_description: ResourceDescription) -> Result<()> { + self.tunnel.add_resource(resource_description).await } #[tracing::instrument(level = "trace", skip(self))] @@ -143,27 +142,28 @@ impl ControlPlane { .await { tunnel.cleanup_connection(resource_id); - tunnel.callbacks().on_error(&err, Recoverable); + tunnel.callbacks().on_error(&err); } } Err(err) => { tunnel.cleanup_connection(resource_id); - tunnel.callbacks().on_error(&err, Recoverable); + tunnel.callbacks().on_error(&err); } } }); } #[tracing::instrument(level = "trace", skip(self))] - pub(super) async fn handle_message(&mut self, msg: Messages) { + pub(super) async fn handle_message(&mut self, msg: Messages) -> Result<()> { match msg { - Messages::Init(init) => self.init(init).await, + Messages::Init(init) => self.init(init).await?, Messages::Relays(connection_details) => self.relays(connection_details), Messages::Connect(connect) => self.connect(connect).await, - Messages::ResourceAdded(resource) => self.add_resource(resource).await, + Messages::ResourceAdded(resource) => self.add_resource(resource).await?, Messages::ResourceRemoved(resource) => self.remove_resource(resource.id), Messages::ResourceUpdated(resource) => self.update_resource(resource), } + Ok(()) } #[tracing::instrument(level = "trace", skip(self))] @@ -175,7 +175,7 @@ impl ControlPlane { tracing::error!( "An offline error came back with a reference to a non-valid resource id" ); - self.tunnel.callbacks().on_error(&Error::ControlProtocolError, ErrorType::Recoverable); + self.tunnel.callbacks().on_error(&Error::ControlProtocolError); return; }; // TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection @@ -187,7 +187,7 @@ impl ControlPlane { ); self.tunnel .callbacks() - .on_error(&Error::ControlProtocolError, ErrorType::Recoverable); + .on_error(&Error::ControlProtocolError); } } } diff --git a/rust/connlib/libs/client/src/lib.rs b/rust/connlib/libs/client/src/lib.rs index 2cc10699c..3f0fca162 100644 --- a/rust/connlib/libs/client/src/lib.rs +++ b/rust/connlib/libs/client/src/lib.rs @@ -18,8 +18,6 @@ pub type Session = libs_common::Session< CB, >; -pub use libs_common::{ - error_type::ErrorType, get_user_agent, Callbacks, Error, ResourceList, TunnelAddresses, -}; +pub use libs_common::{get_user_agent, Callbacks, Error, ResourceList, TunnelAddresses}; use messages::Messages; use messages::ReplyMessages; diff --git a/rust/connlib/libs/common/Cargo.toml b/rust/connlib/libs/common/Cargo.toml index 3774f1ce9..e927c2547 100644 --- a/rust/connlib/libs/common/Cargo.toml +++ b/rust/connlib/libs/common/Cargo.toml @@ -29,6 +29,7 @@ boringtun = { workspace = true } os_info = { version = "3", default-features = false } rand = { version = "0.8", default-features = false, features = ["std"] } chrono = { workspace = true } +parking_lot = "0.12" [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies] swift-bridge = { workspace = true } diff --git a/rust/connlib/libs/common/src/error_type.rs b/rust/connlib/libs/common/src/error_type.rs deleted file mode 100644 index a0e99638f..000000000 --- a/rust/connlib/libs/common/src/error_type.rs +++ /dev/null @@ -1,16 +0,0 @@ -//! Module that contains the Error-Type that hints how to handle an error to upper layers. - -/// This indicates whether the produced error is something recoverable or fatal. -/// Fata/Recoverable only indicates how to handle the error for the client. -/// -/// Any of the errors in [ConnlibError][crate::error::ConnlibError] could be of any [ErrorType] depending the circumstance. -#[derive(Debug, Clone, Copy)] -pub enum ErrorType { - /// Recoverable means that the session can continue - /// e.g. Failed to send an SDP - Recoverable, - /// Fatal error means that the session should stop and start again, - /// generally after user input, such as clicking connect once more. - /// e.g. Max number of retries was reached when trying to connect to the portal. - Fatal, -} diff --git a/rust/connlib/libs/common/src/lib.rs b/rust/connlib/libs/common/src/lib.rs index 6c5459c39..6c696169d 100644 --- a/rust/connlib/libs/common/src/lib.rs +++ b/rust/connlib/libs/common/src/lib.rs @@ -4,7 +4,6 @@ //! we are using the same version across our own crates. pub mod error; -pub mod error_type; mod session; diff --git a/rust/connlib/libs/common/src/session.rs b/rust/connlib/libs/common/src/session.rs index 9e2269266..d39083a22 100644 --- a/rust/connlib/libs/common/src/session.rs +++ b/rust/connlib/libs/common/src/session.rs @@ -1,11 +1,13 @@ use async_trait::async_trait; use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use boringtun::x25519::{PublicKey, StaticSecret}; +use parking_lot::Mutex; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use rand_core::OsRng; use std::{ marker::PhantomData, net::{Ipv4Addr, Ipv6Addr}, + sync::Arc, time::Duration, }; use tokio::{runtime::Runtime, sync::mpsc::Receiver}; @@ -14,7 +16,6 @@ use uuid::Uuid; use crate::{ control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic}, - error_type::ErrorType, messages::{Key, ResourceDescription, ResourceDescriptionCidr}, Error, Result, }; @@ -44,7 +45,7 @@ pub trait ControlSession { /// /// A session is created using [Session::connect], then to stop a session we use [Session::disconnect]. pub struct Session { - runtime: Option, + runtime: Arc>>, callbacks: CB, _phantom: PhantomData<(T, U, V, R, M)>, } @@ -64,7 +65,6 @@ pub struct TunnelAddresses { pub address6: Ipv6Addr, } -// Evaluate doing this not static /// Traits that will be used by connlib to callback the client upper layers. pub trait Callbacks: Clone + Send + Sync { /// Called when the tunnel address is set. @@ -78,21 +78,20 @@ pub trait Callbacks: Clone + Send + Sync { /// Called when the resource list changes. fn on_update_resources(&self, resource_list: ResourceList); /// Called when the tunnel is disconnected. - fn on_disconnect(&self); - /// Called when there's an error. /// - /// # Parameters - /// - `error`: The actual error that happened. - /// - `error_type`: Whether the error should terminate the session or not. - fn on_error(&self, error: &Error, error_type: ErrorType); + /// If the tunnel disconnected due to a fatal error, `error` is the error + /// that caused the disconnect. + fn on_disconnect(&self, error: Option<&Error>); + /// Called when there's a recoverable error. + fn on_error(&self, error: &Error); } macro_rules! fatal_error { - ($result:expr, $c:expr) => { + ($result:expr, $rt:expr, $cb:expr) => { match $result { Ok(res) => res, - Err(e) => { - $c.on_error(&e, ErrorType::Fatal); + Err(err) => { + Self::disconnect_inner($rt, $cb, Some(err)); return; } } @@ -112,6 +111,7 @@ where /// (Used for the gateways). pub fn wait_for_ctrl_c(&mut self) -> Result<()> { self.runtime + .lock() .as_ref() .ok_or(Error::NoRuntime)? .block_on(async { @@ -138,32 +138,48 @@ where // Big question here however is how do we get the result? We could block here await the result and spawn a new task. // but then platforms should know that this function is blocking. - let portal_url = portal_url.try_into().map_err(|_| Error::UriError)?; - - let runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build()?; - - if cfg!(feature = "mock") { - Self::connect_mock(callbacks.clone()); - } else { - Self::connect_inner(&runtime, portal_url, token, callbacks.clone()); - } - - Ok(Self { - runtime: Some(runtime), + let this = Self { + runtime: Mutex::new(Some( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?, + )) + .into(), callbacks, _phantom: PhantomData, - }) + }; + + if cfg!(feature = "mock") { + Self::connect_mock(this.callbacks.clone()); + } else { + Self::connect_inner( + Arc::clone(&this.runtime), + portal_url.try_into().map_err(|_| Error::UriError)?, + token, + this.callbacks.clone(), + ); + } + + Ok(this) } - fn connect_inner(runtime: &Runtime, portal_url: Url, token: String, callbacks: CB) { - runtime.spawn(async move { + fn connect_inner( + runtime: Arc>>, + portal_url: Url, + token: String, + callbacks: CB, + ) { + let runtime_disconnector = Arc::clone(&runtime); + runtime.lock().as_ref().unwrap().spawn(async move { let private_key = StaticSecret::random_from_rng(OsRng); let self_id = uuid::Uuid::new_v4(); let name_suffix: String = thread_rng().sample_iter(&Alphanumeric).take(8).map(char::from).collect(); - let connect_url = fatal_error!(get_websocket_path(portal_url, token, T::socket_path(), &Key(PublicKey::from(&private_key).to_bytes()), &self_id.to_string(), &name_suffix), callbacks); + let connect_url = fatal_error!( + get_websocket_path(portal_url, token, T::socket_path(), &Key(PublicKey::from(&private_key).to_bytes()), &self_id.to_string(), &name_suffix), + &runtime_disconnector, + &callbacks + ); // This is kinda hacky, the buffer size is 1 so that we make sure that we @@ -184,7 +200,11 @@ where // Used to send internal messages let topic = T::socket_path().to_string(); let internal_sender = connection.sender_with_topic(topic.clone()); - fatal_error!(T::start(private_key, control_plane_receiver, internal_sender, callbacks.clone()).await, callbacks); + fatal_error!( + T::start(private_key, control_plane_receiver, internal_sender, callbacks.clone()).await, + &runtime_disconnector, + &callbacks + ); tokio::spawn(async move { let mut exponential_backoff = ExponentialBackoffBuilder::default().build(); @@ -193,18 +213,15 @@ where let result = connection.start(vec![topic.clone()], || exponential_backoff.reset()).await; if let Some(t) = exponential_backoff.next_backoff() { tracing::warn!("Error during connection to the portal, retrying in {} seconds", t.as_secs()); - match result { - Ok(()) => callbacks.on_error(&tokio_tungstenite::tungstenite::Error::ConnectionClosed.into(), ErrorType::Recoverable), - Err(e) => callbacks.on_error(&e, ErrorType::Recoverable) - } + callbacks.on_error(&result.err().unwrap_or(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))); tokio::time::sleep(t).await; } else { tracing::error!("Connection to the portal error, check your internet or the status of the portal.\nDisconnecting interface."); - match result { - Ok(()) => callbacks.on_error(&crate::Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed), ErrorType::Fatal), - Err(e) => callbacks.on_error(&e, ErrorType::Fatal) - } - break; + fatal_error!( + result.and(Err(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))), + &runtime_disconnector, + &callbacks + ); } } @@ -251,12 +268,8 @@ where }); } - /// Cleanup a [Session]. - /// - /// For now this just drops the runtime, which should drop all pending tasks. - /// Further cleanup should be done here. (Otherwise we can just drop [Session]). - pub fn disconnect(&mut self) -> bool { - self.callbacks.on_disconnect(); + fn disconnect_inner(runtime: &Mutex>, callbacks: &CB, error: Option) { + callbacks.on_disconnect(error.as_ref()); // 1. Close the websocket connection // 2. Free the device handle (UNIX) @@ -269,7 +282,15 @@ where // So always yield and if you spawn a blocking tasks rewrite this. // Furthermore, we will depend on Drop impls to do the list above so, // implement them :) - self.runtime = None; + *runtime.lock() = None; + } + + /// Cleanup a [Session]. + /// + /// For now this just drops the runtime, which should drop all pending tasks. + /// Further cleanup should be done here. (Otherwise we can just drop [Session]). + pub fn disconnect(&mut self, error: Option) -> bool { + Self::disconnect_inner(&self.runtime, &self.callbacks, error); true } diff --git a/rust/connlib/libs/gateway/src/control.rs b/rust/connlib/libs/gateway/src/control.rs index 754bc7b6e..e2da522c5 100644 --- a/rust/connlib/libs/gateway/src/control.rs +++ b/rust/connlib/libs/gateway/src/control.rs @@ -4,7 +4,6 @@ use boringtun::x25519::StaticSecret; use firezone_tunnel::{ControlSignal, Tunnel}; use libs_common::{ control::{MessageResult, PhoenixSenderWithTopic}, - error_type::ErrorType::{Fatal, Recoverable}, messages::ResourceDescription, Callbacks, ControlSession, Result, }; @@ -36,13 +35,13 @@ impl ControlSignal for ControlSignaler { impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] - async fn start(mut self, mut receiver: Receiver>) { + async fn start(mut self, mut receiver: Receiver>) -> Result<()> { let mut interval = tokio::time::interval(Duration::from_secs(10)); loop { tokio::select! { Some(msg) = receiver.recv() => { match msg { - Ok(msg) => self.handle_message(msg).await, + Ok(msg) => self.handle_message(msg).await?, Err(_msg_reply) => todo!(), } }, @@ -50,18 +49,19 @@ impl ControlPlane { else => break } } + Ok(()) } #[tracing::instrument(level = "trace", skip_all)] - async fn init(&mut self, init: InitGateway) { + async fn init(&mut self, init: InitGateway) -> Result<()> { if let Err(e) = self.tunnel.set_interface(&init.interface).await { tracing::error!("Couldn't initialize interface: {e}"); - self.tunnel.callbacks().on_error(&e, Fatal); - return; + Err(e) + } else { + // TODO: Enable masquerading here. + tracing::info!("Firezoned Started!"); + Ok(()) } - - // TODO: Enable masquerading here. - tracing::info!("Firezoned Started!"); } #[tracing::instrument(level = "trace", skip(self))] @@ -89,12 +89,12 @@ impl ControlPlane { .await { tunnel.cleanup_connection(connection_request.device.id); - tunnel.callbacks().on_error(&err, Recoverable); + tunnel.callbacks().on_error(&err); } } Err(err) => { tunnel.cleanup_connection(connection_request.device.id); - tunnel.callbacks().on_error(&err, Recoverable); + tunnel.callbacks().on_error(&err); } } }); @@ -106,9 +106,9 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - pub(super) async fn handle_message(&mut self, msg: IngressMessages) { + pub(super) async fn handle_message(&mut self, msg: IngressMessages) -> Result<()> { match msg { - IngressMessages::Init(init) => self.init(init).await, + IngressMessages::Init(init) => self.init(init).await?, IngressMessages::RequestConnection(connection_request) => { self.connection_request(connection_request) } @@ -116,6 +116,7 @@ impl ControlPlane { IngressMessages::RemoveResource(_) => todo!(), IngressMessages::UpdateResource(_) => todo!(), } + Ok(()) } #[tracing::instrument(level = "trace", skip(self))] diff --git a/rust/connlib/libs/gateway/src/lib.rs b/rust/connlib/libs/gateway/src/lib.rs index 41e3394f7..b94f2c1d0 100644 --- a/rust/connlib/libs/gateway/src/lib.rs +++ b/rust/connlib/libs/gateway/src/lib.rs @@ -19,4 +19,4 @@ pub type Session = libs_common::Session< CB, >; -pub use libs_common::{error_type::ErrorType, Callbacks, Error, ResourceList, TunnelAddresses}; +pub use libs_common::{Callbacks, Error, ResourceList, TunnelAddresses}; diff --git a/rust/connlib/libs/tunnel/src/control_protocol.rs b/rust/connlib/libs/tunnel/src/control_protocol.rs index 290147364..d8b80c171 100644 --- a/rust/connlib/libs/tunnel/src/control_protocol.rs +++ b/rust/connlib/libs/tunnel/src/control_protocol.rs @@ -6,7 +6,6 @@ use chrono::{DateTime, Utc}; use std::sync::Arc; use libs_common::{ - error_type::ErrorType::Recoverable, messages::{Id, Key, Relay, RequestConnection}, Callbacks, Error, Result, }; @@ -165,7 +164,7 @@ where let Some(gateway_public_key) = tunnel.gateway_public_keys.lock().remove(&resource_id) else { tunnel.cleanup_connection(resource_id); tracing::warn!("Opened ICE channel with gateway without ever receiving public key"); - tunnel.callbacks.on_error(&Error::ControlProtocolError, Recoverable); + tunnel.callbacks.on_error(&Error::ControlProtocolError); return; }; let peer_config = PeerConfig { @@ -177,7 +176,7 @@ where if let Err(e) = tunnel.handle_channel_open(d, index, peer_config, None, resource_id).await { tracing::error!("Couldn't establish wireguard link after channel was opened: {e}"); - tunnel.callbacks.on_error(&e, Recoverable); + tunnel.callbacks.on_error(&e); tunnel.cleanup_connection(resource_id); } tunnel.awaiting_connection.lock().remove(&resource_id); @@ -283,7 +282,7 @@ where for ip in &peer.ips { if let Err(e) = iface_config.add_route(ip, tunnel.callbacks()).await { - tunnel.callbacks.on_error(&e, Recoverable); + tunnel.callbacks.on_error(&e); } } } @@ -298,7 +297,7 @@ where ) .await { - tunnel.callbacks.on_error(&e, Recoverable); + tunnel.callbacks.on_error(&e); tracing::error!( "Couldn't establish wireguard link after opening channel: {e}" ); @@ -308,7 +307,7 @@ where if let Some(conn) = conn { if let Err(e) = conn.close().await { tracing::error!("Problem while trying to close channel: {e:?}"); - tunnel.callbacks().on_error(&e.into(), Recoverable); + tunnel.callbacks().on_error(&e.into()); } } } diff --git a/rust/connlib/libs/tunnel/src/lib.rs b/rust/connlib/libs/tunnel/src/lib.rs index a3a9a89d8..4e23d15e0 100644 --- a/rust/connlib/libs/tunnel/src/lib.rs +++ b/rust/connlib/libs/tunnel/src/lib.rs @@ -11,10 +11,7 @@ use boringtun::{ }; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -use libs_common::{ - error_type::ErrorType::{Fatal, Recoverable}, - Callbacks, -}; +use libs_common::Callbacks; use async_trait::async_trait; use bytes::Bytes; @@ -215,26 +212,20 @@ where /// Once added, when a packet for the resource is intercepted a new data channel will be created /// and packets will be wrapped with wireguard and sent through it. #[tracing::instrument(level = "trace", skip(self))] - pub async fn add_resource(&self, resource_description: ResourceDescription) { + pub async fn add_resource(&self, resource_description: ResourceDescription) -> Result<()> { { let mut iface_config = self.iface_config.lock().await; for ip in resource_description.ips() { - if let Err(err) = iface_config.add_route(&ip, self.callbacks()).await { - self.callbacks.on_error(&err, Fatal); - } + iface_config.add_route(&ip, self.callbacks()).await?; } } let resource_list = { let mut resources = self.resources.write(); resources.insert(resource_description); - resources.resource_list() + resources.resource_list()? }; - match resource_list { - Ok(resource_list) => { - self.callbacks.on_update_resources(resource_list); - } - Err(err) => self.callbacks.on_error(&err.into(), Fatal), - } + self.callbacks.on_update_resources(resource_list); + Ok(()) } /// Sets the interface configuration and starts background tasks. @@ -440,13 +431,13 @@ where async fn write4_device_infallible(&self, packet: &[u8]) { if let Err(e) = self.device_channel.write4(packet).await { - self.callbacks.on_error(&e.into(), Recoverable); + self.callbacks.on_error(&e.into()); } } async fn write6_device_infallible(&self, packet: &[u8]) { if let Err(e) = self.device_channel.write6(packet).await { - self.callbacks.on_error(&e.into(), Recoverable); + self.callbacks.on_error(&e.into()); } } @@ -476,13 +467,13 @@ where Ok(res) => res, Err(err) => { tracing::error!("Couldn't read packet from interface: {err}"); - dev.callbacks.on_error(&err.into(), Recoverable); + dev.callbacks.on_error(&err.into()); continue; } }, Err(err) => { tracing::error!("Couldn't obtain iface mtu: {err}"); - dev.callbacks.on_error(&err, Recoverable); + dev.callbacks.on_error(&err); continue; } } @@ -525,7 +516,7 @@ where // Not a deadlock because this is a different task dev.awaiting_connection.lock().remove(&id); tracing::error!("couldn't start protocol for new connection to resource: {e}"); - dev.callbacks.on_error(&e, Recoverable); + dev.callbacks.on_error(&e); } }); } @@ -544,7 +535,7 @@ where } TunnResult::Err(e) => { tracing::error!(message = "Encapsulate error for resource corresponding to {dst_addr}", error = ?e); - dev.callbacks.on_error(&e.into(), Recoverable); + dev.callbacks.on_error(&e.into()); } TunnResult::WriteToNetwork(packet) => { tracing::trace!("writing iface packet to peer: {dst_addr}"); @@ -565,11 +556,11 @@ where tracing::error!( "Problem while trying to close channel: {e:?}" ); - dev.callbacks().on_error(&e.into(), Recoverable); + dev.callbacks().on_error(&e.into()); } } } - dev.callbacks.on_error(&e.into(), Recoverable); + dev.callbacks.on_error(&e.into()); } } _ => panic!("Unexpected result from encapsulate"), diff --git a/rust/connlib/libs/tunnel/src/peer.rs b/rust/connlib/libs/tunnel/src/peer.rs index f278857b1..618d7f03f 100644 --- a/rust/connlib/libs/tunnel/src/peer.rs +++ b/rust/connlib/libs/tunnel/src/peer.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use chrono::{DateTime, Utc}; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -use libs_common::{error_type::ErrorType, messages::Id, Callbacks, Result}; +use libs_common::{messages::Id, Callbacks, Result}; use parking_lot::Mutex; use webrtc::data::data_channel::DataChannel; @@ -24,7 +24,7 @@ impl Peer { pub(crate) async fn send_infallible(&self, data: &[u8], callbacks: &CB) { if let Err(e) = self.channel.write(&Bytes::copy_from_slice(data)).await { tracing::error!("Couldn't send packet to connected peer: {e}"); - callbacks.on_error(&e.into(), ErrorType::Recoverable); + callbacks.on_error(&e.into()); } }