connlib: Disconnect on fatal error (#1801)

Resolves firezone/product#619

This additionally removes `ErrorType`:
- `on_error` is now exclusively used for recoverable errors, and no
longer has an `error_type` parameter.
- `on_disconnect` now has an optional `error` parameter, which specifies
the fatal error that caused the disconnect if relevant.
This commit is contained in:
Francesca Lovebloom
2023-07-19 15:36:06 -07:00
committed by GitHub
parent b41c4ed9e4
commit e5e18e78a3
17 changed files with 166 additions and 189 deletions

1
rust/Cargo.lock generated
View File

@@ -1654,6 +1654,7 @@ dependencies = [
"futures-util",
"ip_network",
"os_info",
"parking_lot",
"rand",
"rand_core 0.6.4",
"rtnetlink",

View File

@@ -3,9 +3,7 @@
// However, this consideration has made it idiomatic for Java FFI in the Rust
// ecosystem, so it's used here for consistency.
use firezone_client_connlib::{
Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses,
};
use firezone_client_connlib::{Callbacks, Error, ResourceList, Session, TunnelAddresses};
use jni::{
objects::{JClass, JObject, JString, JValue},
JNIEnv,
@@ -51,11 +49,11 @@ impl Callbacks for CallbackHandler {
todo!()
}
fn on_disconnect(&self) {
fn on_disconnect(&self, _error: Option<&Error>) {
todo!()
}
fn on_error(&self, _error: &Error, _error_type: ErrorType) {
fn on_error(&self, _error: &Error) {
todo!()
}
}
@@ -108,7 +106,7 @@ pub unsafe extern "system" fn Java_dev_firezone_connlib_Session_disconnect(
}
let session = unsafe { &mut *session_ptr };
session.disconnect()
session.disconnect(None)
}
/// # Safety

View File

@@ -51,13 +51,14 @@ public class CallbackHandler {
delegate?.onUpdateResources(resourceList: resourceList.resources.toString())
}
func onDisconnect() {
logger.debug("CallbackHandler.onDisconnect")
func onDisconnect(error: SwiftConnlibError) {
logger.debug("CallbackHandler.onDisconnect: \(error, privacy: .public)")
// TODO: convert `error` to `Optional` by checking for `None` case
delegate?.onDisconnect()
}
func onError(error: SwiftConnlibError, error_type: SwiftErrorType) {
logger.debug("CallbackHandler.onError: \(error, privacy: .public) (\(error_type == .Recoverable ? "Recoverable" : "Fatal", privacy: .public)")
delegate?.onError(error: error, isRecoverable: error_type == .Recoverable)
func onError(error: SwiftConnlibError) {
logger.debug("CallbackHandler.onError: \(error, privacy: .public)")
delegate?.onError(error: error, isRecoverable: true)
}
}

View File

@@ -2,9 +2,7 @@
// Swift bridge generated code triggers this below
#![allow(improper_ctypes, non_camel_case_types)]
use firezone_client_connlib::{
Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses,
};
use firezone_client_connlib::{Callbacks, Error, ResourceList, Session, TunnelAddresses};
use std::{net::Ipv4Addr, sync::Arc};
#[swift_bridge::bridge]
@@ -23,6 +21,9 @@ mod ffi {
// TODO: Duplicating these enum variants from `libs/common/src/error.rs` is
// brittle/noisy/tedious
enum SwiftConnlibError {
// `swift-bridge` doesn't seem to support `Option` for Swift function
// arguments...
None,
Io,
Base64DecodeError,
Base64DecodeSliceError,
@@ -46,11 +47,6 @@ mod ffi {
NoMtu,
}
enum SwiftErrorType {
Recoverable,
Fatal,
}
extern "Rust" {
type WrappedSession;
@@ -89,10 +85,10 @@ mod ffi {
fn on_update_resources(&self, resourceList: ResourceList);
#[swift_bridge(swift_name = "onDisconnect")]
fn on_disconnect(&self);
fn on_disconnect(&self, error: SwiftConnlibError);
#[swift_bridge(swift_name = "onError")]
fn on_error(&self, error: SwiftConnlibError, error_type: SwiftErrorType);
fn on_error(&self, error: SwiftConnlibError);
}
}
@@ -130,15 +126,6 @@ impl From<Error> for ffi::SwiftConnlibError {
}
}
impl From<ErrorType> for ffi::SwiftErrorType {
fn from(val: ErrorType) -> Self {
match val {
ErrorType::Recoverable => Self::Recoverable,
ErrorType::Fatal => Self::Fatal,
}
}
}
impl From<ResourceList> for ffi::ResourceList {
fn from(value: ResourceList) -> Self {
Self {
@@ -195,12 +182,16 @@ impl Callbacks for CallbackHandler {
self.0.on_update_resources(resource_list.into())
}
fn on_disconnect(&self) {
self.0.on_disconnect()
fn on_disconnect(&self, error: Option<&Error>) {
self.0.on_disconnect(
error
.map(Into::into)
.unwrap_or(ffi::SwiftConnlibError::None),
)
}
fn on_error(&self, error: &Error, error_type: ErrorType) {
self.0.on_error(error.into(), error_type.into())
fn on_error(&self, error: &Error) {
self.0.on_error(error.into())
}
}
@@ -230,6 +221,6 @@ impl WrappedSession {
}
fn disconnect(&mut self) -> bool {
self.session.disconnect()
self.session.disconnect(None)
}
}

View File

@@ -3,7 +3,7 @@ use clap::Parser;
use std::{net::Ipv4Addr, str::FromStr};
use firezone_client_connlib::{
get_user_agent, Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses,
get_user_agent, Callbacks, Error, ResourceList, Session, TunnelAddresses,
};
use url::Url;
@@ -25,15 +25,12 @@ impl Callbacks for CallbackHandler {
tracing::trace!("Resources updated, current list: {resource_list:?}");
}
fn on_disconnect(&self) {
tracing::trace!("Tunnel disconnected");
fn on_disconnect(&self, error: Option<&Error>) {
tracing::trace!("Tunnel disconnected: {error:?}");
}
fn on_error(&self, error: &Error, error_type: ErrorType) {
match error_type {
ErrorType::Recoverable => tracing::warn!("Encountered error: {error}"),
ErrorType::Fatal => panic!("Encountered fatal error: {error}"),
}
fn on_error(&self, error: &Error) {
tracing::warn!("Encountered recoverable error: {error}");
}
}
@@ -54,7 +51,7 @@ fn main() -> Result<()> {
let mut session = Session::connect(url, secret, CallbackHandler).unwrap();
tracing::info!("Started new session");
session.wait_for_ctrl_c().unwrap();
session.disconnect();
session.disconnect(None);
Ok(())
}

View File

@@ -1,9 +1,7 @@
use anyhow::{Context, Result};
use std::{net::Ipv4Addr, str::FromStr};
use firezone_gateway_connlib::{
Callbacks, Error, ErrorType, ResourceList, Session, TunnelAddresses,
};
use firezone_gateway_connlib::{Callbacks, Error, ResourceList, Session, TunnelAddresses};
use url::Url;
#[derive(Clone)]
@@ -24,15 +22,12 @@ impl Callbacks for CallbackHandler {
tracing::trace!("Resources updated, current list: {resource_list:?}");
}
fn on_disconnect(&self) {
tracing::trace!("Tunnel disconnected");
fn on_disconnect(&self, error: Option<&Error>) {
tracing::trace!("Tunnel disconnected: {error:?}");
}
fn on_error(&self, error: &Error, error_type: ErrorType) {
match error_type {
ErrorType::Recoverable => tracing::warn!("Encountered error: {error}"),
ErrorType::Fatal => panic!("Encountered fatal error: {error}"),
}
fn on_error(&self, error: &Error) {
tracing::warn!("Encountered recoverable error: {error}");
}
}
@@ -46,7 +41,7 @@ fn main() -> Result<()> {
let secret = parse_env_var::<String>(SECRET_ENV_VAR)?;
let mut session = Session::connect(url, secret, CallbackHandler).unwrap();
session.wait_for_ctrl_c().unwrap();
session.disconnect();
session.disconnect(None);
Ok(())
}

View File

@@ -4,7 +4,6 @@ use crate::messages::{Connect, EgressMessages, InitClient, Messages, Relays};
use boringtun::x25519::StaticSecret;
use libs_common::{
control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic},
error_type::ErrorType::{self, Fatal, Recoverable},
messages::{Id, ResourceDescription},
Callbacks, ControlSession, Error, Result,
};
@@ -45,13 +44,13 @@ struct ControlSignaler {
impl<CB: Callbacks + 'static> ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(self))]
async fn start(mut self, mut receiver: Receiver<MessageResult<Messages>>) {
async fn start(mut self, mut receiver: Receiver<MessageResult<Messages>>) -> Result<()> {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
tokio::select! {
Some(msg) = receiver.recv() => {
match msg {
Ok(msg) => self.handle_message(msg).await,
Ok(msg) => self.handle_message(msg).await?,
Err(msg_reply) => self.handle_error(msg_reply).await,
}
},
@@ -59,6 +58,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
else => break
}
}
Ok(())
}
#[tracing::instrument(level = "trace", skip_all)]
@@ -68,18 +68,17 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
interface,
resources,
}: InitClient,
) {
) -> Result<()> {
if let Err(e) = self.tunnel.set_interface(&interface).await {
tracing::error!("Couldn't initialize interface: {e}");
self.tunnel.callbacks().on_error(&e, Fatal);
return;
Err(e)
} else {
for resource_description in resources {
self.add_resource(resource_description).await?;
}
tracing::info!("Firezoned Started!");
Ok(())
}
for resource_description in resources {
self.add_resource(resource_description).await
}
tracing::info!("Firezoned Started!");
}
#[tracing::instrument(level = "trace", skip(self))]
@@ -101,13 +100,13 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
)
.await
{
self.tunnel.callbacks().on_error(&e, Recoverable);
self.tunnel.callbacks().on_error(&e);
}
}
#[tracing::instrument(level = "trace", skip(self))]
async fn add_resource(&self, resource_description: ResourceDescription) {
self.tunnel.add_resource(resource_description).await;
async fn add_resource(&self, resource_description: ResourceDescription) -> Result<()> {
self.tunnel.add_resource(resource_description).await
}
#[tracing::instrument(level = "trace", skip(self))]
@@ -143,27 +142,28 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
.await
{
tunnel.cleanup_connection(resource_id);
tunnel.callbacks().on_error(&err, Recoverable);
tunnel.callbacks().on_error(&err);
}
}
Err(err) => {
tunnel.cleanup_connection(resource_id);
tunnel.callbacks().on_error(&err, Recoverable);
tunnel.callbacks().on_error(&err);
}
}
});
}
#[tracing::instrument(level = "trace", skip(self))]
pub(super) async fn handle_message(&mut self, msg: Messages) {
pub(super) async fn handle_message(&mut self, msg: Messages) -> Result<()> {
match msg {
Messages::Init(init) => self.init(init).await,
Messages::Init(init) => self.init(init).await?,
Messages::Relays(connection_details) => self.relays(connection_details),
Messages::Connect(connect) => self.connect(connect).await,
Messages::ResourceAdded(resource) => self.add_resource(resource).await,
Messages::ResourceAdded(resource) => self.add_resource(resource).await?,
Messages::ResourceRemoved(resource) => self.remove_resource(resource.id),
Messages::ResourceUpdated(resource) => self.update_resource(resource),
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
@@ -175,7 +175,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
tracing::error!(
"An offline error came back with a reference to a non-valid resource id"
);
self.tunnel.callbacks().on_error(&Error::ControlProtocolError, ErrorType::Recoverable);
self.tunnel.callbacks().on_error(&Error::ControlProtocolError);
return;
};
// TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection
@@ -187,7 +187,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
);
self.tunnel
.callbacks()
.on_error(&Error::ControlProtocolError, ErrorType::Recoverable);
.on_error(&Error::ControlProtocolError);
}
}
}

View File

@@ -18,8 +18,6 @@ pub type Session<CB> = libs_common::Session<
CB,
>;
pub use libs_common::{
error_type::ErrorType, get_user_agent, Callbacks, Error, ResourceList, TunnelAddresses,
};
pub use libs_common::{get_user_agent, Callbacks, Error, ResourceList, TunnelAddresses};
use messages::Messages;
use messages::ReplyMessages;

View File

@@ -29,6 +29,7 @@ boringtun = { workspace = true }
os_info = { version = "3", default-features = false }
rand = { version = "0.8", default-features = false, features = ["std"] }
chrono = { workspace = true }
parking_lot = "0.12"
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
swift-bridge = { workspace = true }

View File

@@ -1,16 +0,0 @@
//! Module that contains the Error-Type that hints how to handle an error to upper layers.
/// This indicates whether the produced error is something recoverable or fatal.
/// Fata/Recoverable only indicates how to handle the error for the client.
///
/// Any of the errors in [ConnlibError][crate::error::ConnlibError] could be of any [ErrorType] depending the circumstance.
#[derive(Debug, Clone, Copy)]
pub enum ErrorType {
/// Recoverable means that the session can continue
/// e.g. Failed to send an SDP
Recoverable,
/// Fatal error means that the session should stop and start again,
/// generally after user input, such as clicking connect once more.
/// e.g. Max number of retries was reached when trying to connect to the portal.
Fatal,
}

View File

@@ -4,7 +4,6 @@
//! we are using the same version across our own crates.
pub mod error;
pub mod error_type;
mod session;

View File

@@ -1,11 +1,13 @@
use async_trait::async_trait;
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use boringtun::x25519::{PublicKey, StaticSecret};
use parking_lot::Mutex;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use rand_core::OsRng;
use std::{
marker::PhantomData,
net::{Ipv4Addr, Ipv6Addr},
sync::Arc,
time::Duration,
};
use tokio::{runtime::Runtime, sync::mpsc::Receiver};
@@ -14,7 +16,6 @@ use uuid::Uuid;
use crate::{
control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic},
error_type::ErrorType,
messages::{Key, ResourceDescription, ResourceDescriptionCidr},
Error, Result,
};
@@ -44,7 +45,7 @@ pub trait ControlSession<T, CB: Callbacks> {
///
/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect].
pub struct Session<T, U, V, R, M, CB: Callbacks> {
runtime: Option<Runtime>,
runtime: Arc<Mutex<Option<Runtime>>>,
callbacks: CB,
_phantom: PhantomData<(T, U, V, R, M)>,
}
@@ -64,7 +65,6 @@ pub struct TunnelAddresses {
pub address6: Ipv6Addr,
}
// Evaluate doing this not static
/// Traits that will be used by connlib to callback the client upper layers.
pub trait Callbacks: Clone + Send + Sync {
/// Called when the tunnel address is set.
@@ -78,21 +78,20 @@ pub trait Callbacks: Clone + Send + Sync {
/// Called when the resource list changes.
fn on_update_resources(&self, resource_list: ResourceList);
/// Called when the tunnel is disconnected.
fn on_disconnect(&self);
/// Called when there's an error.
///
/// # Parameters
/// - `error`: The actual error that happened.
/// - `error_type`: Whether the error should terminate the session or not.
fn on_error(&self, error: &Error, error_type: ErrorType);
/// If the tunnel disconnected due to a fatal error, `error` is the error
/// that caused the disconnect.
fn on_disconnect(&self, error: Option<&Error>);
/// Called when there's a recoverable error.
fn on_error(&self, error: &Error);
}
macro_rules! fatal_error {
($result:expr, $c:expr) => {
($result:expr, $rt:expr, $cb:expr) => {
match $result {
Ok(res) => res,
Err(e) => {
$c.on_error(&e, ErrorType::Fatal);
Err(err) => {
Self::disconnect_inner($rt, $cb, Some(err));
return;
}
}
@@ -112,6 +111,7 @@ where
/// (Used for the gateways).
pub fn wait_for_ctrl_c(&mut self) -> Result<()> {
self.runtime
.lock()
.as_ref()
.ok_or(Error::NoRuntime)?
.block_on(async {
@@ -138,32 +138,48 @@ where
// Big question here however is how do we get the result? We could block here await the result and spawn a new task.
// but then platforms should know that this function is blocking.
let portal_url = portal_url.try_into().map_err(|_| Error::UriError)?;
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
if cfg!(feature = "mock") {
Self::connect_mock(callbacks.clone());
} else {
Self::connect_inner(&runtime, portal_url, token, callbacks.clone());
}
Ok(Self {
runtime: Some(runtime),
let this = Self {
runtime: Mutex::new(Some(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?,
))
.into(),
callbacks,
_phantom: PhantomData,
})
};
if cfg!(feature = "mock") {
Self::connect_mock(this.callbacks.clone());
} else {
Self::connect_inner(
Arc::clone(&this.runtime),
portal_url.try_into().map_err(|_| Error::UriError)?,
token,
this.callbacks.clone(),
);
}
Ok(this)
}
fn connect_inner(runtime: &Runtime, portal_url: Url, token: String, callbacks: CB) {
runtime.spawn(async move {
fn connect_inner(
runtime: Arc<Mutex<Option<Runtime>>>,
portal_url: Url,
token: String,
callbacks: CB,
) {
let runtime_disconnector = Arc::clone(&runtime);
runtime.lock().as_ref().unwrap().spawn(async move {
let private_key = StaticSecret::random_from_rng(OsRng);
let self_id = uuid::Uuid::new_v4();
let name_suffix: String = thread_rng().sample_iter(&Alphanumeric).take(8).map(char::from).collect();
let connect_url = fatal_error!(get_websocket_path(portal_url, token, T::socket_path(), &Key(PublicKey::from(&private_key).to_bytes()), &self_id.to_string(), &name_suffix), callbacks);
let connect_url = fatal_error!(
get_websocket_path(portal_url, token, T::socket_path(), &Key(PublicKey::from(&private_key).to_bytes()), &self_id.to_string(), &name_suffix),
&runtime_disconnector,
&callbacks
);
// This is kinda hacky, the buffer size is 1 so that we make sure that we
@@ -184,7 +200,11 @@ where
// Used to send internal messages
let topic = T::socket_path().to_string();
let internal_sender = connection.sender_with_topic(topic.clone());
fatal_error!(T::start(private_key, control_plane_receiver, internal_sender, callbacks.clone()).await, callbacks);
fatal_error!(
T::start(private_key, control_plane_receiver, internal_sender, callbacks.clone()).await,
&runtime_disconnector,
&callbacks
);
tokio::spawn(async move {
let mut exponential_backoff = ExponentialBackoffBuilder::default().build();
@@ -193,18 +213,15 @@ where
let result = connection.start(vec![topic.clone()], || exponential_backoff.reset()).await;
if let Some(t) = exponential_backoff.next_backoff() {
tracing::warn!("Error during connection to the portal, retrying in {} seconds", t.as_secs());
match result {
Ok(()) => callbacks.on_error(&tokio_tungstenite::tungstenite::Error::ConnectionClosed.into(), ErrorType::Recoverable),
Err(e) => callbacks.on_error(&e, ErrorType::Recoverable)
}
callbacks.on_error(&result.err().unwrap_or(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed)));
tokio::time::sleep(t).await;
} else {
tracing::error!("Connection to the portal error, check your internet or the status of the portal.\nDisconnecting interface.");
match result {
Ok(()) => callbacks.on_error(&crate::Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed), ErrorType::Fatal),
Err(e) => callbacks.on_error(&e, ErrorType::Fatal)
}
break;
fatal_error!(
result.and(Err(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))),
&runtime_disconnector,
&callbacks
);
}
}
@@ -251,12 +268,8 @@ where
});
}
/// Cleanup a [Session].
///
/// For now this just drops the runtime, which should drop all pending tasks.
/// Further cleanup should be done here. (Otherwise we can just drop [Session]).
pub fn disconnect(&mut self) -> bool {
self.callbacks.on_disconnect();
fn disconnect_inner(runtime: &Mutex<Option<Runtime>>, callbacks: &CB, error: Option<Error>) {
callbacks.on_disconnect(error.as_ref());
// 1. Close the websocket connection
// 2. Free the device handle (UNIX)
@@ -269,7 +282,15 @@ where
// So always yield and if you spawn a blocking tasks rewrite this.
// Furthermore, we will depend on Drop impls to do the list above so,
// implement them :)
self.runtime = None;
*runtime.lock() = None;
}
/// Cleanup a [Session].
///
/// For now this just drops the runtime, which should drop all pending tasks.
/// Further cleanup should be done here. (Otherwise we can just drop [Session]).
pub fn disconnect(&mut self, error: Option<Error>) -> bool {
Self::disconnect_inner(&self.runtime, &self.callbacks, error);
true
}

View File

@@ -4,7 +4,6 @@ use boringtun::x25519::StaticSecret;
use firezone_tunnel::{ControlSignal, Tunnel};
use libs_common::{
control::{MessageResult, PhoenixSenderWithTopic},
error_type::ErrorType::{Fatal, Recoverable},
messages::ResourceDescription,
Callbacks, ControlSession, Result,
};
@@ -36,13 +35,13 @@ impl ControlSignal for ControlSignaler {
impl<CB: Callbacks + 'static> ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(self))]
async fn start(mut self, mut receiver: Receiver<MessageResult<IngressMessages>>) {
async fn start(mut self, mut receiver: Receiver<MessageResult<IngressMessages>>) -> Result<()> {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
tokio::select! {
Some(msg) = receiver.recv() => {
match msg {
Ok(msg) => self.handle_message(msg).await,
Ok(msg) => self.handle_message(msg).await?,
Err(_msg_reply) => todo!(),
}
},
@@ -50,18 +49,19 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
else => break
}
}
Ok(())
}
#[tracing::instrument(level = "trace", skip_all)]
async fn init(&mut self, init: InitGateway) {
async fn init(&mut self, init: InitGateway) -> Result<()> {
if let Err(e) = self.tunnel.set_interface(&init.interface).await {
tracing::error!("Couldn't initialize interface: {e}");
self.tunnel.callbacks().on_error(&e, Fatal);
return;
Err(e)
} else {
// TODO: Enable masquerading here.
tracing::info!("Firezoned Started!");
Ok(())
}
// TODO: Enable masquerading here.
tracing::info!("Firezoned Started!");
}
#[tracing::instrument(level = "trace", skip(self))]
@@ -89,12 +89,12 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
.await
{
tunnel.cleanup_connection(connection_request.device.id);
tunnel.callbacks().on_error(&err, Recoverable);
tunnel.callbacks().on_error(&err);
}
}
Err(err) => {
tunnel.cleanup_connection(connection_request.device.id);
tunnel.callbacks().on_error(&err, Recoverable);
tunnel.callbacks().on_error(&err);
}
}
});
@@ -106,9 +106,9 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
pub(super) async fn handle_message(&mut self, msg: IngressMessages) {
pub(super) async fn handle_message(&mut self, msg: IngressMessages) -> Result<()> {
match msg {
IngressMessages::Init(init) => self.init(init).await,
IngressMessages::Init(init) => self.init(init).await?,
IngressMessages::RequestConnection(connection_request) => {
self.connection_request(connection_request)
}
@@ -116,6 +116,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
IngressMessages::RemoveResource(_) => todo!(),
IngressMessages::UpdateResource(_) => todo!(),
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]

View File

@@ -19,4 +19,4 @@ pub type Session<CB> = libs_common::Session<
CB,
>;
pub use libs_common::{error_type::ErrorType, Callbacks, Error, ResourceList, TunnelAddresses};
pub use libs_common::{Callbacks, Error, ResourceList, TunnelAddresses};

View File

@@ -6,7 +6,6 @@ use chrono::{DateTime, Utc};
use std::sync::Arc;
use libs_common::{
error_type::ErrorType::Recoverable,
messages::{Id, Key, Relay, RequestConnection},
Callbacks, Error, Result,
};
@@ -165,7 +164,7 @@ where
let Some(gateway_public_key) = tunnel.gateway_public_keys.lock().remove(&resource_id) else {
tunnel.cleanup_connection(resource_id);
tracing::warn!("Opened ICE channel with gateway without ever receiving public key");
tunnel.callbacks.on_error(&Error::ControlProtocolError, Recoverable);
tunnel.callbacks.on_error(&Error::ControlProtocolError);
return;
};
let peer_config = PeerConfig {
@@ -177,7 +176,7 @@ where
if let Err(e) = tunnel.handle_channel_open(d, index, peer_config, None, resource_id).await {
tracing::error!("Couldn't establish wireguard link after channel was opened: {e}");
tunnel.callbacks.on_error(&e, Recoverable);
tunnel.callbacks.on_error(&e);
tunnel.cleanup_connection(resource_id);
}
tunnel.awaiting_connection.lock().remove(&resource_id);
@@ -283,7 +282,7 @@ where
for ip in &peer.ips {
if let Err(e) = iface_config.add_route(ip, tunnel.callbacks()).await
{
tunnel.callbacks.on_error(&e, Recoverable);
tunnel.callbacks.on_error(&e);
}
}
}
@@ -298,7 +297,7 @@ where
)
.await
{
tunnel.callbacks.on_error(&e, Recoverable);
tunnel.callbacks.on_error(&e);
tracing::error!(
"Couldn't establish wireguard link after opening channel: {e}"
);
@@ -308,7 +307,7 @@ where
if let Some(conn) = conn {
if let Err(e) = conn.close().await {
tracing::error!("Problem while trying to close channel: {e:?}");
tunnel.callbacks().on_error(&e.into(), Recoverable);
tunnel.callbacks().on_error(&e.into());
}
}
}

View File

@@ -11,10 +11,7 @@ use boringtun::{
};
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use libs_common::{
error_type::ErrorType::{Fatal, Recoverable},
Callbacks,
};
use libs_common::Callbacks;
use async_trait::async_trait;
use bytes::Bytes;
@@ -215,26 +212,20 @@ where
/// Once added, when a packet for the resource is intercepted a new data channel will be created
/// and packets will be wrapped with wireguard and sent through it.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_resource(&self, resource_description: ResourceDescription) {
pub async fn add_resource(&self, resource_description: ResourceDescription) -> Result<()> {
{
let mut iface_config = self.iface_config.lock().await;
for ip in resource_description.ips() {
if let Err(err) = iface_config.add_route(&ip, self.callbacks()).await {
self.callbacks.on_error(&err, Fatal);
}
iface_config.add_route(&ip, self.callbacks()).await?;
}
}
let resource_list = {
let mut resources = self.resources.write();
resources.insert(resource_description);
resources.resource_list()
resources.resource_list()?
};
match resource_list {
Ok(resource_list) => {
self.callbacks.on_update_resources(resource_list);
}
Err(err) => self.callbacks.on_error(&err.into(), Fatal),
}
self.callbacks.on_update_resources(resource_list);
Ok(())
}
/// Sets the interface configuration and starts background tasks.
@@ -440,13 +431,13 @@ where
async fn write4_device_infallible(&self, packet: &[u8]) {
if let Err(e) = self.device_channel.write4(packet).await {
self.callbacks.on_error(&e.into(), Recoverable);
self.callbacks.on_error(&e.into());
}
}
async fn write6_device_infallible(&self, packet: &[u8]) {
if let Err(e) = self.device_channel.write6(packet).await {
self.callbacks.on_error(&e.into(), Recoverable);
self.callbacks.on_error(&e.into());
}
}
@@ -476,13 +467,13 @@ where
Ok(res) => res,
Err(err) => {
tracing::error!("Couldn't read packet from interface: {err}");
dev.callbacks.on_error(&err.into(), Recoverable);
dev.callbacks.on_error(&err.into());
continue;
}
},
Err(err) => {
tracing::error!("Couldn't obtain iface mtu: {err}");
dev.callbacks.on_error(&err, Recoverable);
dev.callbacks.on_error(&err);
continue;
}
}
@@ -525,7 +516,7 @@ where
// Not a deadlock because this is a different task
dev.awaiting_connection.lock().remove(&id);
tracing::error!("couldn't start protocol for new connection to resource: {e}");
dev.callbacks.on_error(&e, Recoverable);
dev.callbacks.on_error(&e);
}
});
}
@@ -544,7 +535,7 @@ where
}
TunnResult::Err(e) => {
tracing::error!(message = "Encapsulate error for resource corresponding to {dst_addr}", error = ?e);
dev.callbacks.on_error(&e.into(), Recoverable);
dev.callbacks.on_error(&e.into());
}
TunnResult::WriteToNetwork(packet) => {
tracing::trace!("writing iface packet to peer: {dst_addr}");
@@ -565,11 +556,11 @@ where
tracing::error!(
"Problem while trying to close channel: {e:?}"
);
dev.callbacks().on_error(&e.into(), Recoverable);
dev.callbacks().on_error(&e.into());
}
}
}
dev.callbacks.on_error(&e.into(), Recoverable);
dev.callbacks.on_error(&e.into());
}
}
_ => panic!("Unexpected result from encapsulate"),

View File

@@ -5,7 +5,7 @@ use bytes::Bytes;
use chrono::{DateTime, Utc};
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use libs_common::{error_type::ErrorType, messages::Id, Callbacks, Result};
use libs_common::{messages::Id, Callbacks, Result};
use parking_lot::Mutex;
use webrtc::data::data_channel::DataChannel;
@@ -24,7 +24,7 @@ impl Peer {
pub(crate) async fn send_infallible<CB: Callbacks>(&self, data: &[u8], callbacks: &CB) {
if let Err(e) = self.channel.write(&Bytes::copy_from_slice(data)).await {
tracing::error!("Couldn't send packet to connected peer: {e}");
callbacks.on_error(&e.into(), ErrorType::Recoverable);
callbacks.on_error(&e.into());
}
}