Duplicate Session for client and gateway (#2169)

At the moment, there is a lot of indirection between the tunnel, session,
control plane etc. I think this is much easier to understand if we
don't have as many type parameters and instead create one `Session`
type per "kind" of deployment: clients and gateways.

This is an initial step; as a result, there is now some duplication between gateways
and clients. I'd recommend reviewing patch by patch.

I originally started this to tackle
https://github.com/firezone/firezone/issues/2158, but that is not
possible until we have _concrete_ message types within each `Session`,
hence I am sending this PR first.
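
Concretely, the shape of the change (condensed from the diff below; bodies elided):

```rust
// Before: one Session in libs-common, generic over the control plane and
// every message type, wired up per crate through a `ControlSession` trait.
pub struct Session<T, U, V, R, M, CB: Callbacks> {
    runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
    pub callbacks: CallbackErrorFacade<CB>,
    _phantom: PhantomData<(T, U, V, R, M)>,
}

// After: the client and gateway crates each own a concrete Session that is
// generic only over the callbacks; the message types are fixed per "kind".
pub struct Session<CB: Callbacks> {
    runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
    pub callbacks: CallbackErrorFacade<CB>,
}
```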
Author: Thomas Eizinger
Date: 2023-09-28 12:57:44 +10:00
Committed by: GitHub
Parent: 21afdf0a9a
Commit: 0ceecc0c0e
13 changed files with 762 additions and 685 deletions

rust/Cargo.lock (generated; 26 lines changed)
View File

@@ -509,14 +509,12 @@ dependencies = [
"hmac",
"ip_network",
"ip_network_table",
"jni 0.19.0",
"libc",
"nix 0.25.1",
"parking_lot",
"rand_core",
"ring",
"tracing",
"tracing-subscriber",
"untrusted 0.9.0",
"x25519-dalek",
]
@@ -739,7 +737,7 @@ version = "0.1.6"
dependencies = [
"firezone-client-connlib",
"ip_network",
"jni 0.21.1",
"jni",
"log",
"serde_json",
"thiserror",
@@ -1153,14 +1151,17 @@ dependencies = [
"chrono",
"firezone-tunnel",
"libs-common",
"rand",
"serde",
"serde_json",
"tokio",
"tokio-tungstenite",
"tracing",
"tracing-android",
"tracing-appender",
"tracing-stackdriver",
"tracing-subscriber",
"url",
"webrtc",
]
@@ -1174,10 +1175,13 @@ dependencies = [
"chrono",
"firezone-tunnel",
"libs-common",
"rand",
"serde",
"serde_json",
"tokio",
"tokio-tungstenite",
"tracing",
"url",
"webrtc",
]
@@ -1743,20 +1747,6 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "jni"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
dependencies = [
"cesu8",
"combine",
"jni-sys",
"log",
"thiserror",
"walkdir",
]
[[package]]
name = "jni"
version = "0.21.1"
@@ -1838,8 +1828,6 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4"
name = "libs-common"
version = "0.1.0"
dependencies = [
"async-trait",
"backoff",
"base64 0.21.4",
"boringtun",
"chrono",

View File

@@ -19,6 +19,9 @@ serde = { version = "1.0", default-features = false, features = ["std", "derive"] }
boringtun = { workspace = true }
backoff = { workspace = true }
webrtc = "0.8"
url = { version = "2.4.1", default-features = false }
rand = { version = "0.8", default-features = false, features = ["std"] }
tokio-tungstenite = { version = "0.19", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
[target.'cfg(target_os = "android")'.dependencies]
tracing = { workspace = true, features = ["std", "attributes"] }

View File

@@ -1,15 +1,13 @@
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use crate::messages::{
BroadcastGatewayIceCandidates, Connect, ConnectionDetails, EgressMessages,
GatewayIceCandidates, InitClient, Messages,
};
use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
use boringtun::x25519::StaticSecret;
use libs_common::{
control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic, Reference},
control::{ErrorInfo, ErrorReply, PhoenixSenderWithTopic, Reference},
messages::{GatewayId, ResourceDescription, ResourceId},
Callbacks, ControlSession,
Callbacks,
Error::{self, ControlProtocolError},
Result,
};
@@ -17,7 +15,7 @@ use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
use async_trait::async_trait;
use firezone_tunnel::{ConnId, ControlSignal, Request, Tunnel};
use tokio::sync::{mpsc::Receiver, Mutex};
use tokio::sync::Mutex;
#[async_trait]
impl ControlSignal for ControlSignaler {
@@ -67,42 +65,20 @@ impl ControlSignal for ControlSignaler {
}
}
/// Implementation of [ControlSession] for clients.
pub struct ControlPlane<CB: Callbacks> {
tunnel: Arc<Tunnel<ControlSignaler, CB>>,
control_signaler: ControlSignaler,
tunnel_init: Mutex<bool>,
pub tunnel: Arc<Tunnel<ControlSignaler, CB>>,
pub control_signaler: ControlSignaler,
pub tunnel_init: Mutex<bool>,
}
#[derive(Clone)]
struct ControlSignaler {
control_signal: PhoenixSenderWithTopic,
pub struct ControlSignaler {
pub control_signal: PhoenixSenderWithTopic,
}
impl<CB: Callbacks + 'static> ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(self))]
async fn start(
mut self,
mut receiver: Receiver<(MessageResult<Messages>, Option<Reference>)>,
) -> Result<()> {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
tokio::select! {
Some((msg, reference)) = receiver.recv() => {
match msg {
Ok(msg) => self.handle_message(msg, reference).await?,
Err(err) => self.handle_error(err, reference).await,
}
},
_ = interval.tick() => self.stats_event().await,
else => break
}
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
async fn init(
pub async fn init(
&mut self,
InitClient {
interface,
@@ -131,7 +107,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
async fn connect(
pub async fn connect(
&mut self,
Connect {
gateway_rtc_session_description,
@@ -154,7 +130,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
async fn add_resource(&self, resource_description: ResourceDescription) {
pub async fn add_resource(&self, resource_description: ResourceDescription) {
if let Err(e) = self.tunnel.add_resource(resource_description).await {
tracing::error!(message = "Can't add resource", error = ?e);
let _ = self.tunnel.callbacks().on_error(&e);
@@ -248,7 +224,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
pub(super) async fn handle_message(
pub async fn handle_message(
&mut self,
msg: Messages,
reference: Option<Reference>,
@@ -268,11 +244,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
pub(super) async fn handle_error(
&mut self,
reply_error: ErrorReply,
reference: Option<Reference>,
) {
pub async fn handle_error(&mut self, reply_error: ErrorReply, reference: Option<Reference>) {
if matches!(reply_error.error, ErrorInfo::Offline) {
match reference {
Some(reference) => {
@@ -302,39 +274,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
}
pub(super) async fn stats_event(&mut self) {
pub async fn stats_event(&mut self) {
tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats());
}
}
#[async_trait]
impl<CB: Callbacks + 'static> ControlSession<Messages, CB> for ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(private_key, callbacks))]
async fn start(
private_key: StaticSecret,
receiver: Receiver<(MessageResult<Messages>, Option<Reference>)>,
control_signal: PhoenixSenderWithTopic,
callbacks: CB,
) -> Result<()> {
let control_signaler = ControlSignaler { control_signal };
let tunnel = Arc::new(Tunnel::new(private_key, control_signaler.clone(), callbacks).await?);
let control_plane = ControlPlane {
tunnel,
control_signaler,
tunnel_init: Mutex::new(false),
};
tokio::spawn(async move { control_plane.start(receiver).await });
Ok(())
}
fn socket_path() -> &'static str {
"client"
}
fn retry_strategy() -> ExponentialBackoff {
ExponentialBackoffBuilder::default().build()
}
}

View File

@@ -1,25 +1,246 @@
//! Main connlib library for clients.
pub use libs_common::{get_device_id, messages::ResourceDescription};
pub use libs_common::{Callbacks, Error};
pub use tracing_appender::non_blocking::WorkerGuard;
use crate::control::ControlSignaler;
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use boringtun::x25519::{PublicKey, StaticSecret};
use control::ControlPlane;
use messages::EgressMessages;
use firezone_tunnel::Tunnel;
use libs_common::{
control::PhoenixChannel, get_websocket_path, messages::Key, sha256, CallbackErrorFacade, Result,
};
use messages::IngressMessages;
use messages::Messages;
use messages::ReplyMessages;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use std::sync::Arc;
use std::time::Duration;
use tokio::{runtime::Runtime, sync::Mutex};
use url::Url;
mod control;
pub mod file_logger;
mod messages;
/// Session type for clients.
///
/// For more information see libs_common docs on [Session][libs_common::Session].
pub type Session<CB> = libs_common::Session<
ControlPlane<CB>,
IngressMessages,
EgressMessages,
ReplyMessages,
Messages,
CB,
>;
struct StopRuntime;
pub use libs_common::{get_device_id, messages::ResourceDescription, Callbacks, Error};
use messages::Messages;
use messages::ReplyMessages;
pub use tracing_appender::non_blocking::WorkerGuard;
/// A session is the entry point for connlib; it maintains the runtime and the tunnel.
///
/// A session is created using [Session::connect]; to stop a session, use [Session::disconnect].
pub struct Session<CB: Callbacks> {
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
pub callbacks: CallbackErrorFacade<CB>,
}
macro_rules! fatal_error {
($result:expr, $rt:expr, $cb:expr) => {
match $result {
Ok(res) => res,
Err(err) => {
Self::disconnect_inner($rt, $cb, Some(err));
return;
}
}
};
}
impl<CB> Session<CB>
where
CB: Callbacks + 'static,
{
/// Starts a session in the background.
///
/// This will:
/// 1. Create and start a tokio runtime
/// 2. Connect the control plane to the portal
/// 3. Start the tunnel in the background and forward control plane messages to it.
///
/// The generic parameter `CB` should implement all the handlers; that's how errors will be surfaced.
///
/// On a fatal error you should call [`Session::disconnect`] and start a new one.
// TODO: token should be something like SecretString but we need to think about FFI compatibility
pub fn connect(
portal_url: impl TryInto<Url>,
token: String,
device_id: String,
callbacks: CB,
) -> Result<Self> {
// TODO: We could use tokio::runtime::current() to get the current runtime,
// which could work with swift-rust that already runs a runtime. But IDK if that will work
// on all platforms; a couple of new threads shouldn't bother anyone.
// The big question, however, is how we get the result. We could block here awaiting the result and spawn a new task,
// but then platforms would need to know that this function is blocking.
let callbacks = CallbackErrorFacade(callbacks);
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let this = Self {
runtime_stopper: tx.clone(),
callbacks,
};
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
{
let callbacks = this.callbacks.clone();
let default_panic_hook = std::panic::take_hook();
std::panic::set_hook(Box::new({
let tx = tx.clone();
move |info| {
let tx = tx.clone();
let err = info
.payload()
.downcast_ref::<&str>()
.map(|s| Error::Panic(s.to_string()))
.unwrap_or(Error::PanicNonStringPayload);
Self::disconnect_inner(tx, &callbacks, Some(err));
default_panic_hook(info);
}
}));
}
Self::connect_inner(
&runtime,
tx,
portal_url.try_into().map_err(|_| Error::UriError)?,
token,
device_id,
this.callbacks.clone(),
);
std::thread::spawn(move || {
rx.blocking_recv();
runtime.shutdown_background();
});
Ok(this)
}
fn connect_inner(
runtime: &Runtime,
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
portal_url: Url,
token: String,
device_id: String,
callbacks: CallbackErrorFacade<CB>,
) {
runtime.spawn(async move {
let private_key = StaticSecret::random_from_rng(rand::rngs::OsRng);
let name_suffix: String = thread_rng().sample_iter(&Alphanumeric).take(8).map(char::from).collect();
let external_id = sha256(device_id);
let connect_url = fatal_error!(
get_websocket_path(portal_url, token, "client", &Key(PublicKey::from(&private_key).to_bytes()), &external_id, &name_suffix),
runtime_stopper,
&callbacks
);
// This is kinda hacky: the buffer size is 1 so that we make sure that we
// process one message at a time, blocking if a previous message hasn't been processed,
// to force queue ordering.
let (control_plane_sender, mut control_plane_receiver) = tokio::sync::mpsc::channel(1);
let mut connection = PhoenixChannel::<_, IngressMessages, ReplyMessages, Messages>::new(connect_url, move |msg, reference| {
let control_plane_sender = control_plane_sender.clone();
async move {
tracing::trace!("Received message: {msg:?}");
if let Err(e) = control_plane_sender.send((msg, reference)).await {
tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up.");
}
}
});
let control_signaler = ControlSignaler { control_signal: connection.sender_with_topic("client".to_owned()) };
let tunnel = fatal_error!(
Tunnel::new(private_key, control_signaler.clone(), callbacks.clone()).await,
runtime_stopper,
&callbacks
);
let mut control_plane = ControlPlane {
tunnel: Arc::new(tunnel),
control_signaler,
tunnel_init: Mutex::new(false),
};
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
tokio::select! {
Some((msg, reference)) = control_plane_receiver.recv() => {
match msg {
Ok(msg) => control_plane.handle_message(msg, reference).await?,
Err(err) => control_plane.handle_error(err, reference).await,
}
},
_ = interval.tick() => control_plane.stats_event().await,
else => break
}
}
Result::Ok(())
});
tokio::spawn(async move {
let mut exponential_backoff = ExponentialBackoffBuilder::default().build();
loop {
// `connection.start` calls the callback only after connecting
tracing::debug!("Attempting connection to portal...");
let result = connection.start(vec!["client".to_owned()], || exponential_backoff.reset()).await;
tracing::warn!("Disconnected from the portal");
if let Err(e) = &result {
tracing::warn!(error = ?e, "Portal connection error");
}
if let Some(t) = exponential_backoff.next_backoff() {
tracing::warn!("Error connecting to portal, retrying in {} seconds", t.as_secs());
let _ = callbacks.on_error(&result.err().unwrap_or(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed)));
tokio::time::sleep(t).await;
} else {
tracing::error!("Connection to portal failed, giving up");
fatal_error!(
result.and(Err(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))),
runtime_stopper,
&callbacks
);
}
}
});
});
}
fn disconnect_inner(
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
callbacks: &CallbackErrorFacade<CB>,
error: Option<Error>,
) {
// 1. Close the websocket connection
// 2. Free the device handle (Linux)
// 3. Close the file descriptor (Linux/Android)
// 4. Remove the mapping
// The way we clean up the tasks is to drop the runtime;
// this means we don't need to keep track of different tasks,
// but if any of the tasks never yields, this will block forever!
// So always yield, and if you spawn a blocking task, rewrite this.
// Furthermore, we will depend on Drop impls to do the list above, so
// implement them :)
// If there's no receiver, the runtime is already stopped.
// There's an edge case where this is called before the thread is listening for the stop signal,
// but I believe in that case the channel will be in a signaled state, achieving the same result.
if let Err(err) = runtime_stopper.try_send(StopRuntime) {
tracing::error!("Couldn't stop runtime: {err}");
}
let _ = callbacks.on_disconnect(error.as_ref());
}
/// Cleans up a [Session].
///
/// For now this just drops the runtime, which should drop all pending tasks.
/// Further cleanup should be done here. (Otherwise we could just drop [Session].)
pub fn disconnect(&mut self, error: Option<Error>) {
Self::disconnect_inner(self.runtime_stopper.clone(), &self.callbacks, error)
}
}
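
A minimal sketch of how a caller embeds this `Session` (the `MyCallbacks` type is hypothetical; `connect` and `disconnect` are the methods defined above):

```rust
// Hypothetical caller; `MyCallbacks` is assumed to implement `Callbacks`.
fn run(token: String, device_id: String, callbacks: MyCallbacks) -> Result<()> {
    // Spawns the runtime, connects to the portal and starts the tunnel.
    let mut session = Session::connect("https://portal.example.com", token, device_id, callbacks)?;
    // ... the tunnel now runs on the session's background runtime ...
    session.disconnect(None); // drops the runtime and fires `on_disconnect`
    Ok(())
}
```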

View File

@@ -6,33 +6,30 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
mock = []
jni-bindings = ["boringtun/jni-bindings"]
[dependencies]
base64 = { version = "0.21", default-features = false, features = ["std"] }
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
boringtun = { workspace = true }
chrono = { workspace = true }
futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] }
futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] }
tokio-tungstenite = { version = "0.19", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
webrtc = { version = "0.8" }
uuid = { version = "1.4", default-features = false, features = ["std", "v4", "serde"] }
thiserror = { version = "1.0", default-features = false }
tracing = { workspace = true }
serde_json = { version = "1.0", default-features = false, features = ["std"] }
tokio = { version = "1.32", default-features = false, features = ["rt", "rt-multi-thread"]}
url = { version = "2.4.1", default-features = false }
rand_core = { version = "0.6.4", default-features = false, features = ["std"] }
async-trait = { version = "0.1", default-features = false }
backoff = { workspace = true }
ip_network = { version = "0.4", default-features = false, features = ["serde"] }
boringtun = { workspace = true }
os_info = { version = "3", default-features = false }
rand = { version = "0.8", default-features = false, features = ["std"] }
chrono = { workspace = true }
parking_lot = "0.12"
ring = "0.16"
rand = { version = "0.8", default-features = false, features = ["std"] }
rand_core = { version = "0.6.4", default-features = false, features = ["std"] }
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
serde_json = { version = "1.0", default-features = false, features = ["std"] }
thiserror = { version = "1.0", default-features = false }
tokio = { version = "1.32", default-features = false, features = ["rt", "rt-multi-thread"]}
tokio-stream = { version = "0.1", features = ["time"] }
tokio-tungstenite = { version = "0.19", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
tracing = { workspace = true }
tracing-appender = "0.2"
url = { version = "2.4.1", default-features = false }
uuid = { version = "1.4", default-features = false, features = ["std", "v4", "serde"] }
webrtc = { version = "0.8" }
ring = "0.16"
# Needed for Android logging until tracing is working
log = "0.4"

View File

@@ -0,0 +1,66 @@
use crate::messages::ResourceDescription;
use ip_network::IpNetwork;
use std::error::Error;
use std::fmt::{Debug, Display};
use std::net::{Ipv4Addr, Ipv6Addr};
// Avoids having to map types for Windows
type RawFd = i32;
/// Trait used by connlib to call back into the client's upper layers.
pub trait Callbacks: Clone + Send + Sync {
/// Error returned when a callback fails.
type Error: Debug + Display + Error;
/// Called when the tunnel address is set.
fn on_set_interface_config(
&self,
_: Ipv4Addr,
_: Ipv6Addr,
_: Ipv4Addr,
_: String,
) -> Result<RawFd, Self::Error> {
Ok(-1)
}
/// Called when the tunnel is connected.
fn on_tunnel_ready(&self) -> Result<(), Self::Error> {
tracing::trace!("tunnel_connected");
Ok(())
}
/// Called when a route is added.
fn on_add_route(&self, _: IpNetwork) -> Result<(), Self::Error> {
Ok(())
}
/// Called when a route is removed.
fn on_remove_route(&self, _: IpNetwork) -> Result<(), Self::Error> {
Ok(())
}
/// Called when the resource list changes.
fn on_update_resources(
&self,
resource_list: Vec<ResourceDescription>,
) -> Result<(), Self::Error> {
tracing::trace!(?resource_list, "resource_updated");
Ok(())
}
/// Called when the tunnel is disconnected.
///
/// If the tunnel disconnected due to a fatal error, `error` is the error
/// that caused the disconnect.
fn on_disconnect(&self, error: Option<&crate::Error>) -> Result<(), Self::Error> {
tracing::trace!(error = ?error, "tunnel_disconnected");
// Note that we can't panic here, since the panic hook already routes panics to this function.
std::process::exit(0);
}
/// Called when there's a recoverable error.
fn on_error(&self, error: &crate::Error) -> Result<(), Self::Error> {
tracing::warn!(error = ?error);
Ok(())
}
}
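
Since every method has a default implementation, a minimal implementor only needs to be `Clone + Send + Sync` and pick an error type. A sketch (the `NoOpCallbacks` name is illustrative):

```rust
#[derive(Clone)]
struct NoOpCallbacks;

impl Callbacks for NoOpCallbacks {
    // Any `Debug + Display + Error` type works; connlib only stringifies it.
    type Error = std::convert::Infallible;
    // Everything else falls back to the defaults above, e.g.
    // `on_set_interface_config` returns -1 (no file descriptor) and
    // `on_disconnect` exits the process.
}
```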

View File

@@ -0,0 +1,96 @@
use crate::messages::ResourceDescription;
use crate::{Callbacks, Error, Result};
use ip_network::IpNetwork;
use std::net::{Ipv4Addr, Ipv6Addr};
// Avoids having to map types for Windows
type RawFd = i32;
#[derive(Clone)]
pub struct CallbackErrorFacade<CB>(pub CB);
impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
type Error = Error;
fn on_set_interface_config(
&self,
tunnel_address_v4: Ipv4Addr,
tunnel_address_v6: Ipv6Addr,
dns_address: Ipv4Addr,
dns_fallback_strategy: String,
) -> Result<RawFd> {
let result = self
.0
.on_set_interface_config(
tunnel_address_v4,
tunnel_address_v6,
dns_address,
dns_fallback_strategy,
)
.map_err(|err| Error::OnSetInterfaceConfigFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_tunnel_ready(&self) -> Result<()> {
let result = self
.0
.on_tunnel_ready()
.map_err(|err| Error::OnTunnelReadyFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_add_route(&self, route: IpNetwork) -> Result<()> {
let result = self
.0
.on_add_route(route)
.map_err(|err| Error::OnAddRouteFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_remove_route(&self, route: IpNetwork) -> Result<()> {
let result = self
.0
.on_remove_route(route)
.map_err(|err| Error::OnRemoveRouteFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_update_resources(&self, resource_list: Vec<ResourceDescription>) -> Result<()> {
let result = self
.0
.on_update_resources(resource_list)
.map_err(|err| Error::OnUpdateResourcesFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_disconnect(&self, error: Option<&Error>) -> Result<()> {
if let Err(err) = self.0.on_disconnect(error) {
tracing::error!("`on_disconnect` failed: {err}");
}
// There's nothing we can really do if `on_disconnect` fails.
Ok(())
}
fn on_error(&self, error: &Error) -> Result<()> {
if let Err(err) = self.0.on_error(error) {
tracing::error!("`on_error` failed: {err}");
}
// There's nothing we really want to do if `on_error` fails.
Ok(())
}
}
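
The design choice here: the user's callbacks are wrapped once at the boundary (as `Session::connect` does), so code inside connlib only ever sees the single connlib `Error` type, and every callback failure is already logged. A small sketch reusing the illustrative `NoOpCallbacks` from above:

```rust
fn wrap_example() -> Result<()> {
    let callbacks = CallbackErrorFacade(NoOpCallbacks);
    callbacks.on_tunnel_ready()?; // now returns Result<(), Error>, logging any failure
    Ok(())
}
```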

View File

@@ -3,17 +3,23 @@
//! This includes types provided by external crates, i.e. [boringtun] to make sure that
//! we are using the same version across our own crates.
pub mod error;
mod session;
mod callbacks;
mod callbacks_error_facade;
pub mod control;
pub mod error;
pub mod messages;
pub use callbacks::Callbacks;
pub use callbacks_error_facade::CallbackErrorFacade;
pub use error::ConnlibError as Error;
pub use error::Result;
pub use session::{CallbackErrorFacade, Callbacks, ControlSession, Session, DNS_SENTINEL};
use messages::Key;
use ring::digest::{Context, SHA256};
use std::net::Ipv4Addr;
use url::Url;
pub const DNS_SENTINEL: Ipv4Addr = Ipv4Addr::new(100, 100, 111, 1);
const VERSION: &str = env!("CARGO_PKG_VERSION");
const LIB_NAME: &str = "connlib";
@@ -47,3 +53,55 @@ pub fn get_device_id() -> String {
uuid::Uuid::new_v4().to_string()
}
}
pub fn set_ws_scheme(url: &mut Url) -> Result<()> {
let scheme = match url.scheme() {
"http" | "ws" => "ws",
"https" | "wss" => "wss",
_ => return Err(Error::UriScheme),
};
url.set_scheme(scheme)
.expect("Developer error: the match before this should make sure we can set this");
Ok(())
}
pub fn sha256(input: String) -> String {
let mut ctx = Context::new(&SHA256);
ctx.update(input.as_bytes());
let digest = ctx.finish();
digest
.as_ref()
.iter()
.map(|b| format!("{:02x}", b))
.collect()
}
pub fn get_websocket_path(
mut url: Url,
secret: String,
mode: &str,
public_key: &Key,
external_id: &str,
name_suffix: &str,
) -> Result<Url> {
set_ws_scheme(&mut url)?;
{
let mut paths = url.path_segments_mut().map_err(|_| Error::UriError)?;
paths.pop_if_empty();
paths.push(mode);
paths.push("websocket");
}
{
let mut query_pairs = url.query_pairs_mut();
query_pairs.clear();
query_pairs.append_pair("token", &secret);
query_pairs.append_pair("public_key", &public_key.to_string());
query_pairs.append_pair("external_id", external_id);
query_pairs.append_pair("name_suffix", name_suffix);
}
Ok(url)
}
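
A worked example of what these helpers produce, with illustrative values (`StaticSecret` and `PublicKey` come from boringtun, as in the session code):

```rust
fn example() -> Result<Url> {
    let private_key = StaticSecret::random_from_rng(rand::rngs::OsRng);
    let ws = get_websocket_path(
        Url::parse("https://portal.example.com").map_err(|_| Error::UriError)?,
        "some-token".to_string(),
        "client",
        &Key(PublicKey::from(&private_key).to_bytes()),
        &sha256("device-1234".to_string()), // hex-encoded SHA-256 of the device id
        "abcd1234",
    )?;
    // => wss://portal.example.com/client/websocket?token=some-token
    //    &public_key=<encoded key>&external_id=<hex digest>&name_suffix=abcd1234
    Ok(ws)
}
```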

View File

@@ -1,456 +0,0 @@
use async_trait::async_trait;
use backoff::{backoff::Backoff, ExponentialBackoff};
use boringtun::x25519::{PublicKey, StaticSecret};
use ip_network::IpNetwork;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use rand_core::OsRng;
use ring::digest::{Context, SHA256};
use std::{
error::Error as StdError,
fmt::{Debug, Display},
marker::PhantomData,
net::{Ipv4Addr, Ipv6Addr},
result::Result as StdResult,
};
use tokio::{runtime::Runtime, sync::mpsc::Receiver};
use url::Url;
use crate::{
control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic, Reference},
messages::{Key, ResourceDescription},
Error, Result,
};
pub const DNS_SENTINEL: Ipv4Addr = Ipv4Addr::new(100, 100, 111, 1);
// Avoids having to map types for Windows
type RawFd = i32;
struct StopRuntime;
// TODO: Not the most tidy trait for a control-plane.
/// Trait that represents a control-plane.
#[async_trait]
pub trait ControlSession<T, CB: Callbacks> {
/// Starts the control plane with the given private key in the background.
async fn start(
private_key: StaticSecret,
receiver: Receiver<(MessageResult<T>, Option<Reference>)>,
control_signal: PhoenixSenderWithTopic,
callbacks: CB,
) -> Result<()>;
/// Either "gateway" or "client" used to get the control-plane URL.
fn socket_path() -> &'static str;
/// Retry strategy in case of disconnection for the session.
fn retry_strategy() -> ExponentialBackoff;
}
// TODO: Currently I'm using Session for both gateways and clients;
// however, the gateway could use the runtime directly, which could make things easier,
// so revisit this.
/// A session is the entry point for connlib; it maintains the runtime and the tunnel.
///
/// A session is created using [Session::connect]; to stop a session, use [Session::disconnect].
pub struct Session<T, U, V, R, M, CB: Callbacks> {
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
pub callbacks: CallbackErrorFacade<CB>,
_phantom: PhantomData<(T, U, V, R, M)>,
}
/// Trait used by connlib to call back into the client's upper layers.
pub trait Callbacks: Clone + Send + Sync {
/// Error returned when a callback fails.
type Error: Debug + Display + StdError;
/// Called when the tunnel address is set.
fn on_set_interface_config(
&self,
_: Ipv4Addr,
_: Ipv6Addr,
_: Ipv4Addr,
_: String,
) -> StdResult<RawFd, Self::Error> {
Ok(-1)
}
/// Called when the tunnel is connected.
fn on_tunnel_ready(&self) -> StdResult<(), Self::Error> {
tracing::trace!("tunnel_connected");
Ok(())
}
/// Called when a route is added.
fn on_add_route(&self, _: IpNetwork) -> StdResult<(), Self::Error> {
Ok(())
}
/// Called when a route is removed.
fn on_remove_route(&self, _: IpNetwork) -> StdResult<(), Self::Error> {
Ok(())
}
/// Called when the resource list changes.
fn on_update_resources(
&self,
resource_list: Vec<ResourceDescription>,
) -> StdResult<(), Self::Error> {
tracing::trace!(?resource_list, "resource_updated");
Ok(())
}
/// Called when the tunnel is disconnected.
///
/// If the tunnel disconnected due to a fatal error, `error` is the error
/// that caused the disconnect.
fn on_disconnect(&self, error: Option<&Error>) -> StdResult<(), Self::Error> {
tracing::trace!(error = ?error, "tunnel_disconnected");
// Note that we can't panic here, since the panic hook already routes panics to this function.
std::process::exit(0);
}
/// Called when there's a recoverable error.
fn on_error(&self, error: &Error) -> StdResult<(), Self::Error> {
tracing::warn!(error = ?error);
Ok(())
}
}
#[derive(Clone)]
pub struct CallbackErrorFacade<CB: Callbacks>(pub CB);
impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
type Error = Error;
fn on_set_interface_config(
&self,
tunnel_address_v4: Ipv4Addr,
tunnel_address_v6: Ipv6Addr,
dns_address: Ipv4Addr,
dns_fallback_strategy: String,
) -> Result<RawFd> {
let result = self
.0
.on_set_interface_config(
tunnel_address_v4,
tunnel_address_v6,
dns_address,
dns_fallback_strategy,
)
.map_err(|err| Error::OnSetInterfaceConfigFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_tunnel_ready(&self) -> Result<()> {
let result = self
.0
.on_tunnel_ready()
.map_err(|err| Error::OnTunnelReadyFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_add_route(&self, route: IpNetwork) -> Result<()> {
let result = self
.0
.on_add_route(route)
.map_err(|err| Error::OnAddRouteFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_remove_route(&self, route: IpNetwork) -> Result<()> {
let result = self
.0
.on_remove_route(route)
.map_err(|err| Error::OnRemoveRouteFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_update_resources(&self, resource_list: Vec<ResourceDescription>) -> Result<()> {
let result = self
.0
.on_update_resources(resource_list)
.map_err(|err| Error::OnUpdateResourcesFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
}
result
}
fn on_disconnect(&self, error: Option<&Error>) -> Result<()> {
if let Err(err) = self.0.on_disconnect(error) {
tracing::error!("`on_disconnect` failed: {err}");
}
// There's nothing we can really do if `on_disconnect` fails.
Ok(())
}
fn on_error(&self, error: &Error) -> Result<()> {
if let Err(err) = self.0.on_error(error) {
tracing::error!("`on_error` failed: {err}");
}
// There's nothing we really want to do if `on_error` fails.
Ok(())
}
}
macro_rules! fatal_error {
($result:expr, $rt:expr, $cb:expr) => {
match $result {
Ok(res) => res,
Err(err) => {
Self::disconnect_inner($rt, $cb, Some(err));
return;
}
}
};
}
impl<T, U, V, R, M, CB> Session<T, U, V, R, M, CB>
where
T: ControlSession<M, CB>,
U: for<'de> serde::Deserialize<'de> + std::fmt::Debug + Send + 'static,
R: for<'de> serde::Deserialize<'de> + std::fmt::Debug + Send + 'static,
V: serde::Serialize + Send + 'static,
M: From<U> + From<R> + Send + 'static + std::fmt::Debug,
CB: Callbacks + 'static,
{
/// Starts a session in the background.
///
/// This will:
/// 1. Create and start a tokio runtime
/// 2. Connect the control plane to the portal
/// 3. Start the tunnel in the background and forward control plane messages to it.
///
/// The generic parameter `CB` should implement all the handlers; that's how errors will be surfaced.
///
/// On a fatal error you should call [`Session::disconnect`] and start a new one.
// TODO: token should be something like SecretString but we need to think about FFI compatibility
pub fn connect(
portal_url: impl TryInto<Url>,
token: String,
device_id: String,
callbacks: CB,
) -> Result<Self> {
// TODO: We could use tokio::runtime::current() to get the current runtime,
// which could work with swift-rust that already runs a runtime. But IDK if that will work
// on all platforms; a couple of new threads shouldn't bother anyone.
// The big question, however, is how we get the result. We could block here awaiting the result and spawn a new task,
// but then platforms would need to know that this function is blocking.
let callbacks = CallbackErrorFacade(callbacks);
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let this = Self {
runtime_stopper: tx.clone(),
callbacks,
_phantom: PhantomData,
};
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
{
let callbacks = this.callbacks.clone();
let default_panic_hook = std::panic::take_hook();
std::panic::set_hook(Box::new({
let tx = tx.clone();
move |info| {
let tx = tx.clone();
let err = info
.payload()
.downcast_ref::<&str>()
.map(|s| Error::Panic(s.to_string()))
.unwrap_or(Error::PanicNonStringPayload);
Self::disconnect_inner(tx, &callbacks, Some(err));
default_panic_hook(info);
}
}));
}
Self::connect_inner(
&runtime,
tx,
portal_url.try_into().map_err(|_| Error::UriError)?,
token,
device_id,
this.callbacks.clone(),
);
std::thread::spawn(move || {
rx.blocking_recv();
runtime.shutdown_background();
});
Ok(this)
}
fn connect_inner(
runtime: &Runtime,
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
portal_url: Url,
token: String,
device_id: String,
callbacks: CallbackErrorFacade<CB>,
) {
runtime.spawn(async move {
let private_key = StaticSecret::random_from_rng(OsRng);
let name_suffix: String = thread_rng().sample_iter(&Alphanumeric).take(8).map(char::from).collect();
let external_id = sha256(device_id);
let connect_url = fatal_error!(
get_websocket_path(portal_url, token, T::socket_path(), &Key(PublicKey::from(&private_key).to_bytes()), &external_id, &name_suffix),
runtime_stopper,
&callbacks
);
// This is kinda hacky: the buffer size is 1 so that we make sure that we
// process one message at a time, blocking if a previous message hasn't been processed,
// to force queue ordering.
let (control_plane_sender, control_plane_receiver) = tokio::sync::mpsc::channel(1);
let mut connection = PhoenixChannel::<_, U, R, M>::new(connect_url, move |msg, reference| {
let control_plane_sender = control_plane_sender.clone();
async move {
tracing::trace!("Received message: {msg:?}");
if let Err(e) = control_plane_sender.send((msg, reference)).await {
tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up.");
}
}
});
// Used to send internal messages
let topic = T::socket_path().to_string();
let internal_sender = connection.sender_with_topic(topic.clone());
fatal_error!(
T::start(private_key, control_plane_receiver, internal_sender, callbacks.0.clone()).await,
runtime_stopper,
&callbacks
);
tokio::spawn(async move {
let mut exponential_backoff = T::retry_strategy();
loop {
// `connection.start` calls the callback only after connecting
tracing::debug!("Attempting connection to portal...");
let result = connection.start(vec![topic.clone()], || exponential_backoff.reset()).await;
tracing::warn!("Disconnected from the portal");
if let Err(e) = &result {
tracing::warn!(error = ?e, "Portal connection error");
}
if let Some(t) = exponential_backoff.next_backoff() {
tracing::warn!("Error connecting to portal, retrying in {} seconds", t.as_secs());
let _ = callbacks.on_error(&result.err().unwrap_or(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed)));
tokio::time::sleep(t).await;
} else {
tracing::error!("Connection to portal failed, giving up");
fatal_error!(
result.and(Err(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))),
runtime_stopper,
&callbacks
);
}
}
});
});
}
fn disconnect_inner(
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
callbacks: &CallbackErrorFacade<CB>,
error: Option<Error>,
) {
// 1. Close the websocket connection
// 2. Free the device handle (Linux)
// 3. Close the file descriptor (Linux/Android)
// 4. Remove the mapping
// The way we clean up the tasks is to drop the runtime;
// this means we don't need to keep track of different tasks,
// but if any of the tasks never yields, this will block forever!
// So always yield, and if you spawn a blocking task, rewrite this.
// Furthermore, we will depend on Drop impls to do the list above, so
// implement them :)
// If there's no receiver, the runtime is already stopped.
// There's an edge case where this is called before the thread is listening for the stop signal,
// but I believe in that case the channel will be in a signaled state, achieving the same result.
if let Err(err) = runtime_stopper.try_send(StopRuntime) {
tracing::error!("Couldn't stop runtime: {err}");
}
let _ = callbacks.on_disconnect(error.as_ref());
}
/// Cleans up a [Session].
///
/// For now this just drops the runtime, which should drop all pending tasks.
/// Further cleanup should be done here. (Otherwise we could just drop [Session].)
pub fn disconnect(&mut self, error: Option<Error>) {
Self::disconnect_inner(self.runtime_stopper.clone(), &self.callbacks, error)
}
}
fn set_ws_scheme(url: &mut Url) -> Result<()> {
let scheme = match url.scheme() {
"http" | "ws" => "ws",
"https" | "wss" => "wss",
_ => return Err(Error::UriScheme),
};
url.set_scheme(scheme)
.expect("Developer error: the match before this should make sure we can set this");
Ok(())
}
fn sha256(input: String) -> String {
let mut ctx = Context::new(&SHA256);
ctx.update(input.as_bytes());
let digest = ctx.finish();
digest
.as_ref()
.iter()
.map(|b| format!("{:02x}", b))
.collect()
}
fn get_websocket_path(
mut url: Url,
secret: String,
mode: &str,
public_key: &Key,
external_id: &str,
name_suffix: &str,
) -> Result<Url> {
set_ws_scheme(&mut url)?;
{
let mut paths = url.path_segments_mut().map_err(|_| Error::UriError)?;
paths.pop_if_empty();
paths.push(mode);
paths.push("websocket");
}
{
let mut query_pairs = url.query_pairs_mut();
query_pairs.clear();
query_pairs.append_pair("token", &secret);
query_pairs.append_pair("public_key", &public_key.to_string());
query_pairs.append_pair("external_id", external_id);
query_pairs.append_pair("name_suffix", name_suffix);
}
Ok(url)
}

View File

@@ -14,6 +14,9 @@ boringtun = { workspace = true }
chrono = { workspace = true }
backoff = { workspace = true }
webrtc = "0.8"
url = { version = "2.4.1", default-features = false }
rand = { version = "0.8", default-features = false, features = ["std"] }
tokio-tungstenite = { version = "0.19", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
[dev-dependencies]
serde_json = { version = "1.0", default-features = false, features = ["std"] }

View File

@@ -1,34 +1,26 @@
use std::{sync::Arc, time::Duration};
use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
use boringtun::x25519::StaticSecret;
use firezone_tunnel::{ConnId, ControlSignal, Tunnel};
use libs_common::{
control::{MessageResult, PhoenixSenderWithTopic, Reference},
messages::{GatewayId, ResourceDescription},
Callbacks, ControlSession,
Error::ControlProtocolError,
Result,
};
use tokio::sync::mpsc::Receiver;
use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
use crate::messages::{AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates};
use super::messages::{
ConnectionReady, EgressMessages, IngressMessages, InitGateway, RequestConnection,
};
use crate::messages::{AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates};
use async_trait::async_trait;
use firezone_tunnel::{ConnId, ControlSignal, Tunnel};
use libs_common::Error::ControlProtocolError;
use libs_common::{
control::PhoenixSenderWithTopic,
messages::{GatewayId, ResourceDescription},
Callbacks, Result,
};
use std::sync::Arc;
use webrtc::ice_transport::ice_candidate::RTCIceCandidate;
pub struct ControlPlane<CB: Callbacks> {
tunnel: Arc<Tunnel<ControlSignaler, CB>>,
control_signaler: ControlSignaler,
pub tunnel: Arc<Tunnel<ControlSignaler, CB>>,
pub control_signaler: ControlSignaler,
}
#[derive(Clone)]
struct ControlSignaler {
control_signal: PhoenixSenderWithTopic,
pub struct ControlSignaler {
pub control_signal: PhoenixSenderWithTopic,
}
#[async_trait]
@@ -71,28 +63,7 @@ impl ControlSignal for ControlSignaler {
impl<CB: Callbacks + 'static> ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(self))]
async fn start(
mut self,
mut receiver: Receiver<(MessageResult<IngressMessages>, Option<Reference>)>,
) -> Result<()> {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
tokio::select! {
Some((msg, _)) = receiver.recv() => {
match msg {
Ok(msg) => self.handle_message(msg).await?,
Err(_msg_reply) => todo!(),
}
},
_ = interval.tick() => self.stats_event().await,
else => break
}
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
async fn init(&mut self, init: InitGateway) -> Result<()> {
pub async fn init(&mut self, init: InitGateway) -> Result<()> {
if let Err(e) = self.tunnel.set_interface(&init.interface).await {
tracing::error!("Couldn't initialize interface: {e}");
Err(e)
@@ -104,7 +75,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
fn connection_request(&self, connection_request: RequestConnection) {
pub fn connection_request(&self, connection_request: RequestConnection) {
let tunnel = Arc::clone(&self.tunnel);
let mut control_signaler = self.control_signaler.clone();
tokio::spawn(async move {
@@ -141,7 +112,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
fn allow_access(
pub fn allow_access(
&self,
AllowAccess {
client_id,
@@ -172,7 +143,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
#[tracing::instrument(level = "trace", skip(self))]
pub(super) async fn handle_message(&mut self, msg: IngressMessages) -> Result<()> {
pub async fn handle_message(&mut self, msg: IngressMessages) -> Result<()> {
match msg {
IngressMessages::Init(init) => self.init(init).await?,
IngressMessages::RequestConnection(connection_request) => {
@@ -189,40 +160,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
Ok(())
}
pub(super) async fn stats_event(&mut self) {
pub async fn stats_event(&mut self) {
tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats());
}
}
#[async_trait]
impl<CB: Callbacks + 'static> ControlSession<IngressMessages, CB> for ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(private_key, callbacks))]
async fn start(
private_key: StaticSecret,
receiver: Receiver<(MessageResult<IngressMessages>, Option<Reference>)>,
control_signal: PhoenixSenderWithTopic,
callbacks: CB,
) -> Result<()> {
let control_signaler = ControlSignaler { control_signal };
let tunnel = Arc::new(Tunnel::new(private_key, control_signaler.clone(), callbacks).await?);
let control_plane = ControlPlane {
tunnel,
control_signaler,
};
tokio::spawn(async move { control_plane.start(receiver).await });
Ok(())
}
fn socket_path() -> &'static str {
"gateway"
}
fn retry_strategy() -> ExponentialBackoff {
ExponentialBackoffBuilder::default()
.with_max_elapsed_time(None)
.build()
}
}

View File

@@ -1,22 +1,245 @@
//! Main connlib library for gateways.
pub use libs_common::{get_device_id, messages::ResourceDescription, Callbacks, Error};
use crate::control::ControlSignaler;
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use boringtun::x25519::{PublicKey, StaticSecret};
use control::ControlPlane;
use messages::EgressMessages;
use firezone_tunnel::Tunnel;
use libs_common::{
control::PhoenixChannel, get_websocket_path, messages::Key, sha256, CallbackErrorFacade, Result,
};
use messages::IngressMessages;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Runtime;
use url::Url;
mod control;
mod messages;
/// Session type for gateway.
///
/// For more information see libs_common docs on [Session][libs_common::Session].
// TODO: Still working on gateway messages
pub type Session<CB> = libs_common::Session<
ControlPlane<CB>,
IngressMessages,
EgressMessages,
IngressMessages,
IngressMessages,
CB,
>;
struct StopRuntime;
pub use libs_common::{get_device_id, messages::ResourceDescription, Callbacks, Error};
/// A session is the entry point for connlib; it maintains the runtime and the tunnel.
///
/// A session is created using [Session::connect]; to stop a session, use [Session::disconnect].
pub struct Session<CB: Callbacks> {
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
pub callbacks: CallbackErrorFacade<CB>,
}
macro_rules! fatal_error {
($result:expr, $rt:expr, $cb:expr) => {
match $result {
Ok(res) => res,
Err(err) => {
Self::disconnect_inner($rt, $cb, Some(err));
return;
}
}
};
}
impl<CB> Session<CB>
where
CB: Callbacks + 'static,
{
/// Starts a session in the background.
///
/// This will:
/// 1. Create and start a tokio runtime
/// 2. Connect the control plane to the portal
/// 3. Start the tunnel in the background and forward control plane messages to it.
///
/// The generic parameter `CB` should implement all the handlers; that's how errors will be surfaced.
///
/// On a fatal error you should call [`Session::disconnect`] and start a new one.
// TODO: token should be something like SecretString but we need to think about FFI compatibility
pub fn connect(
portal_url: impl TryInto<Url>,
token: String,
device_id: String,
callbacks: CB,
) -> Result<Self> {
// TODO: We could use tokio::runtime::current() to get the current runtime,
// which could work with swift-rust that already runs a runtime. But IDK if that will work
// on all platforms; a couple of new threads shouldn't bother anyone.
// The big question, however, is how we get the result. We could block here awaiting the result and spawn a new task,
// but then platforms would need to know that this function is blocking.
let callbacks = CallbackErrorFacade(callbacks);
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let this = Self {
runtime_stopper: tx.clone(),
callbacks,
};
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
{
let callbacks = this.callbacks.clone();
let default_panic_hook = std::panic::take_hook();
std::panic::set_hook(Box::new({
let tx = tx.clone();
move |info| {
let tx = tx.clone();
let err = info
.payload()
.downcast_ref::<&str>()
.map(|s| Error::Panic(s.to_string()))
.unwrap_or(Error::PanicNonStringPayload);
Self::disconnect_inner(tx, &callbacks, Some(err));
default_panic_hook(info);
}
}));
}
Self::connect_inner(
&runtime,
tx,
portal_url.try_into().map_err(|_| Error::UriError)?,
token,
device_id,
this.callbacks.clone(),
);
std::thread::spawn(move || {
rx.blocking_recv();
runtime.shutdown_background();
});
Ok(this)
}
fn connect_inner(
runtime: &Runtime,
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
portal_url: Url,
token: String,
device_id: String,
callbacks: CallbackErrorFacade<CB>,
) {
runtime.spawn(async move {
let private_key = StaticSecret::random_from_rng(rand::rngs::OsRng);
let name_suffix: String = thread_rng().sample_iter(&Alphanumeric).take(8).map(char::from).collect();
let external_id = sha256(device_id);
let connect_url = fatal_error!(
get_websocket_path(portal_url, token, "gateway", &Key(PublicKey::from(&private_key).to_bytes()), &external_id, &name_suffix),
runtime_stopper,
&callbacks
);
// This is kinda hacky: the buffer size is 1 so that we make sure that we
// process one message at a time, blocking if a previous message hasn't been processed,
// to force queue ordering.
let (control_plane_sender, mut control_plane_receiver) = tokio::sync::mpsc::channel(1);
let mut connection = PhoenixChannel::<_, IngressMessages, IngressMessages, IngressMessages>::new(connect_url, move |msg, reference| {
let control_plane_sender = control_plane_sender.clone();
async move {
tracing::trace!("Received message: {msg:?}");
if let Err(e) = control_plane_sender.send((msg, reference)).await {
tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up.");
}
}
});
// Used to send internal messages
let control_signaler = ControlSignaler { control_signal: connection.sender_with_topic("gateway".to_owned()) };
let tunnel = fatal_error!(
Tunnel::new(private_key, control_signaler.clone(), callbacks.clone()).await,
runtime_stopper,
&callbacks
);
let mut control_plane = ControlPlane {
tunnel: Arc::new(tunnel),
control_signaler,
};
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
tokio::select! {
Some((msg, _)) = control_plane_receiver.recv() => {
match msg {
Ok(msg) => control_plane.handle_message(msg).await?,
Err(_msg_reply) => todo!(),
}
},
_ = interval.tick() => control_plane.stats_event().await,
else => break
}
}
Result::Ok(())
});
tokio::spawn(async move {
let mut exponential_backoff = ExponentialBackoffBuilder::default()
.with_max_elapsed_time(None)
.build();
loop {
// `connection.start` calls the callback only after connecting
tracing::debug!("Attempting connection to portal...");
let result = connection.start(vec!["gateway".to_owned()], || exponential_backoff.reset()).await;
tracing::warn!("Disconnected from the portal");
if let Err(e) = &result {
tracing::warn!(error = ?e, "Portal connection error");
}
if let Some(t) = exponential_backoff.next_backoff() {
tracing::warn!("Error connecting to portal, retrying in {} seconds", t.as_secs());
let _ = callbacks.on_error(&result.err().unwrap_or(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed)));
tokio::time::sleep(t).await;
} else {
tracing::error!("Connection to portal failed, giving up");
fatal_error!(
result.and(Err(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))),
runtime_stopper,
&callbacks
);
}
}
});
});
}
fn disconnect_inner(
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
callbacks: &CallbackErrorFacade<CB>,
error: Option<Error>,
) {
// 1. Close the websocket connection
// 2. Free the device handle (Linux)
// 3. Close the file descriptor (Linux/Android)
// 4. Remove the mapping
// The way we clean up the tasks is to drop the runtime;
// this means we don't need to keep track of different tasks,
// but if any of the tasks never yields, this will block forever!
// So always yield, and if you spawn a blocking task, rewrite this.
// Furthermore, we will depend on Drop impls to do the list above, so
// implement them :)
// If there's no receiver, the runtime is already stopped.
// There's an edge case where this is called before the thread is listening for the stop signal,
// but I believe in that case the channel will be in a signaled state, achieving the same result.
if let Err(err) = runtime_stopper.try_send(StopRuntime) {
tracing::error!("Couldn't stop runtime: {err}");
}
let _ = callbacks.on_disconnect(error.as_ref());
}
/// Cleans up a [Session].
///
/// For now this just drops the runtime, which should drop all pending tasks.
/// Further cleanup should be done here. (Otherwise we could just drop [Session].)
pub fn disconnect(&mut self, error: Option<Error>) {
Self::disconnect_inner(self.runtime_stopper.clone(), &self.callbacks, error)
}
}
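
Note one behavioural difference carried over from the deleted `retry_strategy()` impls: the client keeps the default exponential backoff (which eventually gives up once the builder's maximum elapsed time passes), while the gateway sets `with_max_elapsed_time(None)` and retries forever. Side by side:

```rust
use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};

// Client: default builder; `next_backoff()` eventually returns `None`,
// which triggers the "giving up" branch of the reconnect loop.
fn client_backoff() -> ExponentialBackoff {
    ExponentialBackoffBuilder::default().build()
}

// Gateway: never stop retrying the portal connection.
fn gateway_backoff() -> ExponentialBackoff {
    ExponentialBackoffBuilder::default()
        .with_max_elapsed_time(None)
        .build()
}
```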

View File

@@ -10,7 +10,7 @@ use bytes::Bytes;
use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use libs_common::{messages::Key, Callbacks, Error, DNS_SENTINEL};
use libs_common::{messages::Key, CallbackErrorFacade, Callbacks, Error, DNS_SENTINEL};
use serde::{Deserialize, Serialize};
use async_trait::async_trait;
@@ -35,7 +35,7 @@ use libs_common::{
messages::{
ClientId, GatewayId, Interface as InterfaceConfig, ResourceDescription, ResourceId,
},
CallbackErrorFacade, Result,
Result,
};
use device_channel::{create_iface, DeviceIo, IfaceConfig};