diff --git a/rust/Cargo.lock b/rust/Cargo.lock index ceb0994c6..86ba3fe86 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -509,14 +509,12 @@ dependencies = [ "hmac", "ip_network", "ip_network_table", - "jni 0.19.0", "libc", "nix 0.25.1", "parking_lot", "rand_core", "ring", "tracing", - "tracing-subscriber", "untrusted 0.9.0", "x25519-dalek", ] @@ -739,7 +737,7 @@ version = "0.1.6" dependencies = [ "firezone-client-connlib", "ip_network", - "jni 0.21.1", + "jni", "log", "serde_json", "thiserror", @@ -1153,14 +1151,17 @@ dependencies = [ "chrono", "firezone-tunnel", "libs-common", + "rand", "serde", "serde_json", "tokio", + "tokio-tungstenite", "tracing", "tracing-android", "tracing-appender", "tracing-stackdriver", "tracing-subscriber", + "url", "webrtc", ] @@ -1174,10 +1175,13 @@ dependencies = [ "chrono", "firezone-tunnel", "libs-common", + "rand", "serde", "serde_json", "tokio", + "tokio-tungstenite", "tracing", + "url", "webrtc", ] @@ -1743,20 +1747,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "jni" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" -dependencies = [ - "cesu8", - "combine", - "jni-sys", - "log", - "thiserror", - "walkdir", -] - [[package]] name = "jni" version = "0.21.1" @@ -1838,8 +1828,6 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" name = "libs-common" version = "0.1.0" dependencies = [ - "async-trait", - "backoff", "base64 0.21.4", "boringtun", "chrono", diff --git a/rust/connlib/libs/client/Cargo.toml b/rust/connlib/libs/client/Cargo.toml index 45c02f75c..1bfec5819 100644 --- a/rust/connlib/libs/client/Cargo.toml +++ b/rust/connlib/libs/client/Cargo.toml @@ -19,6 +19,9 @@ serde = { version = "1.0", default-features = false, features = ["std", "derive" boringtun = { workspace = true } backoff = { workspace = true } webrtc = "0.8" +url = { version = "2.4.1", default-features = false } +rand = { version = "0.8", default-features = false, features = ["std"] } +tokio-tungstenite = { version = "0.19", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] } [target.'cfg(target_os = "android")'.dependencies] tracing = { workspace = true, features = ["std", "attributes"] } diff --git a/rust/connlib/libs/client/src/control.rs b/rust/connlib/libs/client/src/control.rs index 27e3c9b2d..6c0b4d5d0 100644 --- a/rust/connlib/libs/client/src/control.rs +++ b/rust/connlib/libs/client/src/control.rs @@ -1,15 +1,13 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use crate::messages::{ BroadcastGatewayIceCandidates, Connect, ConnectionDetails, EgressMessages, GatewayIceCandidates, InitClient, Messages, }; -use backoff::{ExponentialBackoff, ExponentialBackoffBuilder}; -use boringtun::x25519::StaticSecret; use libs_common::{ - control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic, Reference}, + control::{ErrorInfo, ErrorReply, PhoenixSenderWithTopic, Reference}, messages::{GatewayId, ResourceDescription, ResourceId}, - Callbacks, ControlSession, + Callbacks, Error::{self, ControlProtocolError}, Result, }; @@ -17,7 +15,7 @@ use webrtc::ice_transport::ice_candidate::RTCIceCandidate; use async_trait::async_trait; use firezone_tunnel::{ConnId, ControlSignal, Request, Tunnel}; -use tokio::sync::{mpsc::Receiver, Mutex}; +use tokio::sync::Mutex; #[async_trait] impl ControlSignal for ControlSignaler { @@ -67,42 +65,20 @@ impl ControlSignal for ControlSignaler { } } -/// Implementation of [ControlSession] for clients. pub struct ControlPlane { - tunnel: Arc>, - control_signaler: ControlSignaler, - tunnel_init: Mutex, + pub tunnel: Arc>, + pub control_signaler: ControlSignaler, + pub tunnel_init: Mutex, } #[derive(Clone)] -struct ControlSignaler { - control_signal: PhoenixSenderWithTopic, +pub struct ControlSignaler { + pub control_signal: PhoenixSenderWithTopic, } impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] - async fn start( - mut self, - mut receiver: Receiver<(MessageResult, Option)>, - ) -> Result<()> { - let mut interval = tokio::time::interval(Duration::from_secs(10)); - loop { - tokio::select! { - Some((msg, reference)) = receiver.recv() => { - match msg { - Ok(msg) => self.handle_message(msg, reference).await?, - Err(err) => self.handle_error(err, reference).await, - } - }, - _ = interval.tick() => self.stats_event().await, - else => break - } - } - Ok(()) - } - - #[tracing::instrument(level = "trace", skip(self))] - async fn init( + pub async fn init( &mut self, InitClient { interface, @@ -131,7 +107,7 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - async fn connect( + pub async fn connect( &mut self, Connect { gateway_rtc_session_description, @@ -154,7 +130,7 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - async fn add_resource(&self, resource_description: ResourceDescription) { + pub async fn add_resource(&self, resource_description: ResourceDescription) { if let Err(e) = self.tunnel.add_resource(resource_description).await { tracing::error!(message = "Can't add resource", error = ?e); let _ = self.tunnel.callbacks().on_error(&e); @@ -248,7 +224,7 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - pub(super) async fn handle_message( + pub async fn handle_message( &mut self, msg: Messages, reference: Option, @@ -268,11 +244,7 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - pub(super) async fn handle_error( - &mut self, - reply_error: ErrorReply, - reference: Option, - ) { + pub async fn handle_error(&mut self, reply_error: ErrorReply, reference: Option) { if matches!(reply_error.error, ErrorInfo::Offline) { match reference { Some(reference) => { @@ -302,39 +274,7 @@ impl ControlPlane { } } - pub(super) async fn stats_event(&mut self) { + pub async fn stats_event(&mut self) { tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats()); } } - -#[async_trait] -impl ControlSession for ControlPlane { - #[tracing::instrument(level = "trace", skip(private_key, callbacks))] - async fn start( - private_key: StaticSecret, - receiver: Receiver<(MessageResult, Option)>, - control_signal: PhoenixSenderWithTopic, - callbacks: CB, - ) -> Result<()> { - let control_signaler = ControlSignaler { control_signal }; - let tunnel = Arc::new(Tunnel::new(private_key, control_signaler.clone(), callbacks).await?); - - let control_plane = ControlPlane { - tunnel, - control_signaler, - tunnel_init: Mutex::new(false), - }; - - tokio::spawn(async move { control_plane.start(receiver).await }); - - Ok(()) - } - - fn socket_path() -> &'static str { - "client" - } - - fn retry_strategy() -> ExponentialBackoff { - ExponentialBackoffBuilder::default().build() - } -} diff --git a/rust/connlib/libs/client/src/lib.rs b/rust/connlib/libs/client/src/lib.rs index bfbd80773..6642de70c 100644 --- a/rust/connlib/libs/client/src/lib.rs +++ b/rust/connlib/libs/client/src/lib.rs @@ -1,25 +1,246 @@ //! Main connlib library for clients. +pub use libs_common::{get_device_id, messages::ResourceDescription}; +pub use libs_common::{Callbacks, Error}; +pub use tracing_appender::non_blocking::WorkerGuard; + +use crate::control::ControlSignaler; +use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; +use boringtun::x25519::{PublicKey, StaticSecret}; use control::ControlPlane; -use messages::EgressMessages; +use firezone_tunnel::Tunnel; +use libs_common::{ + control::PhoenixChannel, get_websocket_path, messages::Key, sha256, CallbackErrorFacade, Result, +}; use messages::IngressMessages; +use messages::Messages; +use messages::ReplyMessages; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use std::sync::Arc; +use std::time::Duration; +use tokio::{runtime::Runtime, sync::Mutex}; +use url::Url; mod control; pub mod file_logger; mod messages; -/// Session type for clients. -/// -/// For more information see libs_common docs on [Session][libs_common::Session]. -pub type Session = libs_common::Session< - ControlPlane, - IngressMessages, - EgressMessages, - ReplyMessages, - Messages, - CB, ->; +struct StopRuntime; -pub use libs_common::{get_device_id, messages::ResourceDescription, Callbacks, Error}; -use messages::Messages; -use messages::ReplyMessages; -pub use tracing_appender::non_blocking::WorkerGuard; +/// A session is the entry-point for connlib, maintains the runtime and the tunnel. +/// +/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect]. +pub struct Session { + runtime_stopper: tokio::sync::mpsc::Sender, + pub callbacks: CallbackErrorFacade, +} + +macro_rules! fatal_error { + ($result:expr, $rt:expr, $cb:expr) => { + match $result { + Ok(res) => res, + Err(err) => { + Self::disconnect_inner($rt, $cb, Some(err)); + return; + } + } + }; +} + +impl Session +where + CB: Callbacks + 'static, +{ + /// Starts a session in the background. + /// + /// This will: + /// 1. Create and start a tokio runtime + /// 2. Connect to the control plane to the portal + /// 3. Start the tunnel in the background and forward control plane messages to it. + /// + /// The generic parameter `CB` should implement all the handlers and that's how errors will be surfaced. + /// + /// On a fatal error you should call `[Session::disconnect]` and start a new one. + // TODO: token should be something like SecretString but we need to think about FFI compatibility + pub fn connect( + portal_url: impl TryInto, + token: String, + device_id: String, + callbacks: CB, + ) -> Result { + // TODO: We could use tokio::runtime::current() to get the current runtime + // which could work with swift-rust that already runs a runtime. But IDK if that will work + // in all platforms, a couple of new threads shouldn't bother none. + // Big question here however is how do we get the result? We could block here await the result and spawn a new task. + // but then platforms should know that this function is blocking. + + let callbacks = CallbackErrorFacade(callbacks); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let this = Self { + runtime_stopper: tx.clone(), + callbacks, + }; + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?; + { + let callbacks = this.callbacks.clone(); + let default_panic_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new({ + let tx = tx.clone(); + move |info| { + let tx = tx.clone(); + let err = info + .payload() + .downcast_ref::<&str>() + .map(|s| Error::Panic(s.to_string())) + .unwrap_or(Error::PanicNonStringPayload); + Self::disconnect_inner(tx, &callbacks, Some(err)); + default_panic_hook(info); + } + })); + } + + Self::connect_inner( + &runtime, + tx, + portal_url.try_into().map_err(|_| Error::UriError)?, + token, + device_id, + this.callbacks.clone(), + ); + std::thread::spawn(move || { + rx.blocking_recv(); + runtime.shutdown_background(); + }); + + Ok(this) + } + + fn connect_inner( + runtime: &Runtime, + runtime_stopper: tokio::sync::mpsc::Sender, + portal_url: Url, + token: String, + device_id: String, + callbacks: CallbackErrorFacade, + ) { + runtime.spawn(async move { + let private_key = StaticSecret::random_from_rng(rand::rngs::OsRng); + let name_suffix: String = thread_rng().sample_iter(&Alphanumeric).take(8).map(char::from).collect(); + let external_id = sha256(device_id); + + let connect_url = fatal_error!( + get_websocket_path(portal_url, token, "client", &Key(PublicKey::from(&private_key).to_bytes()), &external_id, &name_suffix), + runtime_stopper, + &callbacks + ); + + // This is kinda hacky, the buffer size is 1 so that we make sure that we + // process one message at a time, blocking if a previous message haven't been processed + // to force queue ordering. + let (control_plane_sender, mut control_plane_receiver) = tokio::sync::mpsc::channel(1); + + let mut connection = PhoenixChannel::<_, IngressMessages, ReplyMessages, Messages>::new(connect_url, move |msg, reference| { + let control_plane_sender = control_plane_sender.clone(); + async move { + tracing::trace!("Received message: {msg:?}"); + if let Err(e) = control_plane_sender.send((msg, reference)).await { + tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up."); + } + } + }); + + let control_signaler = ControlSignaler { control_signal: connection.sender_with_topic("client".to_owned()) }; + let tunnel = fatal_error!( + Tunnel::new(private_key, control_signaler.clone(), callbacks.clone()).await, + runtime_stopper, + &callbacks + ); + + let mut control_plane = ControlPlane { + tunnel: Arc::new(tunnel), + control_signaler, + tunnel_init: Mutex::new(false), + }; + + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(10)); + loop { + tokio::select! { + Some((msg, reference)) = control_plane_receiver.recv() => { + match msg { + Ok(msg) => control_plane.handle_message(msg, reference).await?, + Err(err) => control_plane.handle_error(err, reference).await, + } + }, + _ = interval.tick() => control_plane.stats_event().await, + else => break + } + } + Result::Ok(()) + }); + + tokio::spawn(async move { + let mut exponential_backoff = ExponentialBackoffBuilder::default().build(); + loop { + // `connection.start` calls the callback only after connecting + tracing::debug!("Attempting connection to portal..."); + let result = connection.start(vec!["client".to_owned()], || exponential_backoff.reset()).await; + tracing::warn!("Disconnected from the portal"); + if let Err(e) = &result { + tracing::warn!(error = ?e, "Portal connection error"); + } + if let Some(t) = exponential_backoff.next_backoff() { + tracing::warn!("Error connecting to portal, retrying in {} seconds", t.as_secs()); + let _ = callbacks.on_error(&result.err().unwrap_or(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))); + tokio::time::sleep(t).await; + } else { + tracing::error!("Connection to portal failed, giving up"); + fatal_error!( + result.and(Err(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))), + runtime_stopper, + &callbacks + ); + } + } + + }); + + }); + } + + fn disconnect_inner( + runtime_stopper: tokio::sync::mpsc::Sender, + callbacks: &CallbackErrorFacade, + error: Option, + ) { + // 1. Close the websocket connection + // 2. Free the device handle (Linux) + // 3. Close the file descriptor (Linux/Android) + // 4. Remove the mapping + + // The way we cleanup the tasks is we drop the runtime + // this means we don't need to keep track of different tasks + // but if any of the tasks never yields this will block forever! + // So always yield and if you spawn a blocking tasks rewrite this. + // Furthermore, we will depend on Drop impls to do the list above so, + // implement them :) + // if there's no receiver the runtime is already stopped + // there's an edge case where this is called before the thread is listening for stop threads. + // but I believe in that case the channel will be in a signaled state achieving the same result + + if let Err(err) = runtime_stopper.try_send(StopRuntime) { + tracing::error!("Couldn't stop runtime: {err}"); + } + + let _ = callbacks.on_disconnect(error.as_ref()); + } + + /// Cleanup a [Session]. + /// + /// For now this just drops the runtime, which should drop all pending tasks. + /// Further cleanup should be done here. (Otherwise we can just drop [Session]). + pub fn disconnect(&mut self, error: Option) { + Self::disconnect_inner(self.runtime_stopper.clone(), &self.callbacks, error) + } +} diff --git a/rust/connlib/libs/common/Cargo.toml b/rust/connlib/libs/common/Cargo.toml index a8e9ef821..bfebc25d9 100644 --- a/rust/connlib/libs/common/Cargo.toml +++ b/rust/connlib/libs/common/Cargo.toml @@ -6,33 +6,30 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] mock = [] -jni-bindings = ["boringtun/jni-bindings"] [dependencies] base64 = { version = "0.21", default-features = false, features = ["std"] } -serde = { version = "1.0", default-features = false, features = ["derive", "std"] } +boringtun = { workspace = true } +chrono = { workspace = true } futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] } futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] } -tokio-tungstenite = { version = "0.19", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] } -webrtc = { version = "0.8" } -uuid = { version = "1.4", default-features = false, features = ["std", "v4", "serde"] } -thiserror = { version = "1.0", default-features = false } -tracing = { workspace = true } -serde_json = { version = "1.0", default-features = false, features = ["std"] } -tokio = { version = "1.32", default-features = false, features = ["rt", "rt-multi-thread"]} -url = { version = "2.4.1", default-features = false } -rand_core = { version = "0.6.4", default-features = false, features = ["std"] } -async-trait = { version = "0.1", default-features = false } -backoff = { workspace = true } ip_network = { version = "0.4", default-features = false, features = ["serde"] } -boringtun = { workspace = true } os_info = { version = "3", default-features = false } -rand = { version = "0.8", default-features = false, features = ["std"] } -chrono = { workspace = true } parking_lot = "0.12" -ring = "0.16" +rand = { version = "0.8", default-features = false, features = ["std"] } +rand_core = { version = "0.6.4", default-features = false, features = ["std"] } +serde = { version = "1.0", default-features = false, features = ["derive", "std"] } +serde_json = { version = "1.0", default-features = false, features = ["std"] } +thiserror = { version = "1.0", default-features = false } +tokio = { version = "1.32", default-features = false, features = ["rt", "rt-multi-thread"]} tokio-stream = { version = "0.1", features = ["time"] } +tokio-tungstenite = { version = "0.19", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] } +tracing = { workspace = true } tracing-appender = "0.2" +url = { version = "2.4.1", default-features = false } +uuid = { version = "1.4", default-features = false, features = ["std", "v4", "serde"] } +webrtc = { version = "0.8" } +ring = "0.16" # Needed for Android logging until tracing is working log = "0.4" diff --git a/rust/connlib/libs/common/src/callbacks.rs b/rust/connlib/libs/common/src/callbacks.rs new file mode 100644 index 000000000..67f402005 --- /dev/null +++ b/rust/connlib/libs/common/src/callbacks.rs @@ -0,0 +1,66 @@ +use crate::messages::ResourceDescription; +use ip_network::IpNetwork; +use std::error::Error; +use std::fmt::{Debug, Display}; +use std::net::{Ipv4Addr, Ipv6Addr}; + +// Avoids having to map types for Windows +type RawFd = i32; + +/// Traits that will be used by connlib to callback the client upper layers. +pub trait Callbacks: Clone + Send + Sync { + /// Error returned when a callback fails. + type Error: Debug + Display + Error; + + /// Called when the tunnel address is set. + fn on_set_interface_config( + &self, + _: Ipv4Addr, + _: Ipv6Addr, + _: Ipv4Addr, + _: String, + ) -> Result { + Ok(-1) + } + + /// Called when the tunnel is connected. + fn on_tunnel_ready(&self) -> Result<(), Self::Error> { + tracing::trace!("tunnel_connected"); + Ok(()) + } + + /// Called when when a route is added. + fn on_add_route(&self, _: IpNetwork) -> Result<(), Self::Error> { + Ok(()) + } + + /// Called when when a route is removed. + fn on_remove_route(&self, _: IpNetwork) -> Result<(), Self::Error> { + Ok(()) + } + + /// Called when the resource list changes. + fn on_update_resources( + &self, + resource_list: Vec, + ) -> Result<(), Self::Error> { + tracing::trace!(?resource_list, "resource_updated"); + Ok(()) + } + + /// Called when the tunnel is disconnected. + /// + /// If the tunnel disconnected due to a fatal error, `error` is the error + /// that caused the disconnect. + fn on_disconnect(&self, error: Option<&crate::Error>) -> Result<(), Self::Error> { + tracing::trace!(error = ?error, "tunnel_disconnected"); + // Note that we can't panic here, since we already hooked the panic to this function. + std::process::exit(0); + } + + /// Called when there's a recoverable error. + fn on_error(&self, error: &crate::Error) -> Result<(), Self::Error> { + tracing::warn!(error = ?error); + Ok(()) + } +} diff --git a/rust/connlib/libs/common/src/callbacks_error_facade.rs b/rust/connlib/libs/common/src/callbacks_error_facade.rs new file mode 100644 index 000000000..e9d9b1882 --- /dev/null +++ b/rust/connlib/libs/common/src/callbacks_error_facade.rs @@ -0,0 +1,96 @@ +use crate::messages::ResourceDescription; +use crate::{Callbacks, Error, Result}; +use ip_network::IpNetwork; +use std::net::{Ipv4Addr, Ipv6Addr}; + +// Avoids having to map types for Windows +type RawFd = i32; + +#[derive(Clone)] +pub struct CallbackErrorFacade(pub CB); + +impl Callbacks for CallbackErrorFacade { + type Error = Error; + + fn on_set_interface_config( + &self, + tunnel_address_v4: Ipv4Addr, + tunnel_address_v6: Ipv6Addr, + dns_address: Ipv4Addr, + dns_fallback_strategy: String, + ) -> Result { + let result = self + .0 + .on_set_interface_config( + tunnel_address_v4, + tunnel_address_v6, + dns_address, + dns_fallback_strategy, + ) + .map_err(|err| Error::OnSetInterfaceConfigFailed(err.to_string())); + if let Err(err) = result.as_ref() { + tracing::error!("{err}"); + } + result + } + + fn on_tunnel_ready(&self) -> Result<()> { + let result = self + .0 + .on_tunnel_ready() + .map_err(|err| Error::OnTunnelReadyFailed(err.to_string())); + if let Err(err) = result.as_ref() { + tracing::error!("{err}"); + } + result + } + + fn on_add_route(&self, route: IpNetwork) -> Result<()> { + let result = self + .0 + .on_add_route(route) + .map_err(|err| Error::OnAddRouteFailed(err.to_string())); + if let Err(err) = result.as_ref() { + tracing::error!("{err}"); + } + result + } + + fn on_remove_route(&self, route: IpNetwork) -> Result<()> { + let result = self + .0 + .on_remove_route(route) + .map_err(|err| Error::OnRemoveRouteFailed(err.to_string())); + if let Err(err) = result.as_ref() { + tracing::error!("{err}"); + } + result + } + + fn on_update_resources(&self, resource_list: Vec) -> Result<()> { + let result = self + .0 + .on_update_resources(resource_list) + .map_err(|err| Error::OnUpdateResourcesFailed(err.to_string())); + if let Err(err) = result.as_ref() { + tracing::error!("{err}"); + } + result + } + + fn on_disconnect(&self, error: Option<&Error>) -> Result<()> { + if let Err(err) = self.0.on_disconnect(error) { + tracing::error!("`on_disconnect` failed: {err}"); + } + // There's nothing we can really do if `on_disconnect` fails. + Ok(()) + } + + fn on_error(&self, error: &Error) -> Result<()> { + if let Err(err) = self.0.on_error(error) { + tracing::error!("`on_error` failed: {err}"); + } + // There's nothing we really want to do if `on_error` fails. + Ok(()) + } +} diff --git a/rust/connlib/libs/common/src/lib.rs b/rust/connlib/libs/common/src/lib.rs index 372c14f5c..60ddf81b2 100644 --- a/rust/connlib/libs/common/src/lib.rs +++ b/rust/connlib/libs/common/src/lib.rs @@ -3,17 +3,23 @@ //! This includes types provided by external crates, i.e. [boringtun] to make sure that //! we are using the same version across our own crates. -pub mod error; - -mod session; - +mod callbacks; +mod callbacks_error_facade; pub mod control; +pub mod error; pub mod messages; +pub use callbacks::Callbacks; +pub use callbacks_error_facade::CallbackErrorFacade; pub use error::ConnlibError as Error; pub use error::Result; -pub use session::{CallbackErrorFacade, Callbacks, ControlSession, Session, DNS_SENTINEL}; +use messages::Key; +use ring::digest::{Context, SHA256}; +use std::net::Ipv4Addr; +use url::Url; + +pub const DNS_SENTINEL: Ipv4Addr = Ipv4Addr::new(100, 100, 111, 1); const VERSION: &str = env!("CARGO_PKG_VERSION"); const LIB_NAME: &str = "connlib"; @@ -47,3 +53,55 @@ pub fn get_device_id() -> String { uuid::Uuid::new_v4().to_string() } } + +pub fn set_ws_scheme(url: &mut Url) -> Result<()> { + let scheme = match url.scheme() { + "http" | "ws" => "ws", + "https" | "wss" => "wss", + _ => return Err(Error::UriScheme), + }; + url.set_scheme(scheme) + .expect("Developer error: the match before this should make sure we can set this"); + Ok(()) +} + +pub fn sha256(input: String) -> String { + let mut ctx = Context::new(&SHA256); + ctx.update(input.as_bytes()); + let digest = ctx.finish(); + + digest + .as_ref() + .iter() + .map(|b| format!("{:02x}", b)) + .collect() +} + +pub fn get_websocket_path( + mut url: Url, + secret: String, + mode: &str, + public_key: &Key, + external_id: &str, + name_suffix: &str, +) -> Result { + set_ws_scheme(&mut url)?; + + { + let mut paths = url.path_segments_mut().map_err(|_| Error::UriError)?; + paths.pop_if_empty(); + paths.push(mode); + paths.push("websocket"); + } + + { + let mut query_pairs = url.query_pairs_mut(); + query_pairs.clear(); + query_pairs.append_pair("token", &secret); + query_pairs.append_pair("public_key", &public_key.to_string()); + query_pairs.append_pair("external_id", external_id); + query_pairs.append_pair("name_suffix", name_suffix); + } + + Ok(url) +} diff --git a/rust/connlib/libs/common/src/session.rs b/rust/connlib/libs/common/src/session.rs deleted file mode 100644 index e2a150a62..000000000 --- a/rust/connlib/libs/common/src/session.rs +++ /dev/null @@ -1,456 +0,0 @@ -use async_trait::async_trait; -use backoff::{backoff::Backoff, ExponentialBackoff}; -use boringtun::x25519::{PublicKey, StaticSecret}; -use ip_network::IpNetwork; -use rand::{distributions::Alphanumeric, thread_rng, Rng}; -use rand_core::OsRng; -use ring::digest::{Context, SHA256}; -use std::{ - error::Error as StdError, - fmt::{Debug, Display}, - marker::PhantomData, - net::{Ipv4Addr, Ipv6Addr}, - result::Result as StdResult, -}; -use tokio::{runtime::Runtime, sync::mpsc::Receiver}; -use url::Url; - -use crate::{ - control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic, Reference}, - messages::{Key, ResourceDescription}, - Error, Result, -}; - -pub const DNS_SENTINEL: Ipv4Addr = Ipv4Addr::new(100, 100, 111, 1); - -// Avoids having to map types for Windows -type RawFd = i32; - -struct StopRuntime; - -// TODO: Not the most tidy trait for a control-plane. -/// Trait that represents a control-plane. -#[async_trait] -pub trait ControlSession { - /// Start control-plane with the given private-key in the background. - async fn start( - private_key: StaticSecret, - receiver: Receiver<(MessageResult, Option)>, - control_signal: PhoenixSenderWithTopic, - callbacks: CB, - ) -> Result<()>; - - /// Either "gateway" or "client" used to get the control-plane URL. - fn socket_path() -> &'static str; - - /// Retry strategy in case of disconnection for the session. - fn retry_strategy() -> ExponentialBackoff; -} - -// TODO: Currently I'm using Session for both gateway and clients -// however, gateway could use the runtime directly and could make things easier -// so revisit this. -/// A session is the entry-point for connlib, maintains the runtime and the tunnel. -/// -/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect]. -pub struct Session { - runtime_stopper: tokio::sync::mpsc::Sender, - pub callbacks: CallbackErrorFacade, - _phantom: PhantomData<(T, U, V, R, M)>, -} - -/// Traits that will be used by connlib to callback the client upper layers. -pub trait Callbacks: Clone + Send + Sync { - /// Error returned when a callback fails. - type Error: Debug + Display + StdError; - - /// Called when the tunnel address is set. - fn on_set_interface_config( - &self, - _: Ipv4Addr, - _: Ipv6Addr, - _: Ipv4Addr, - _: String, - ) -> StdResult { - Ok(-1) - } - - /// Called when the tunnel is connected. - fn on_tunnel_ready(&self) -> StdResult<(), Self::Error> { - tracing::trace!("tunnel_connected"); - Ok(()) - } - - /// Called when when a route is added. - fn on_add_route(&self, _: IpNetwork) -> StdResult<(), Self::Error> { - Ok(()) - } - - /// Called when when a route is removed. - fn on_remove_route(&self, _: IpNetwork) -> StdResult<(), Self::Error> { - Ok(()) - } - - /// Called when the resource list changes. - fn on_update_resources( - &self, - resource_list: Vec, - ) -> StdResult<(), Self::Error> { - tracing::trace!(?resource_list, "resource_updated"); - Ok(()) - } - - /// Called when the tunnel is disconnected. - /// - /// If the tunnel disconnected due to a fatal error, `error` is the error - /// that caused the disconnect. - fn on_disconnect(&self, error: Option<&Error>) -> StdResult<(), Self::Error> { - tracing::trace!(error = ?error, "tunnel_disconnected"); - // Note that we can't panic here, since we already hooked the panic to this function. - std::process::exit(0); - } - - /// Called when there's a recoverable error. - fn on_error(&self, error: &Error) -> StdResult<(), Self::Error> { - tracing::warn!(error = ?error); - Ok(()) - } -} - -#[derive(Clone)] -pub struct CallbackErrorFacade(pub CB); - -impl Callbacks for CallbackErrorFacade { - type Error = Error; - - fn on_set_interface_config( - &self, - tunnel_address_v4: Ipv4Addr, - tunnel_address_v6: Ipv6Addr, - dns_address: Ipv4Addr, - dns_fallback_strategy: String, - ) -> Result { - let result = self - .0 - .on_set_interface_config( - tunnel_address_v4, - tunnel_address_v6, - dns_address, - dns_fallback_strategy, - ) - .map_err(|err| Error::OnSetInterfaceConfigFailed(err.to_string())); - if let Err(err) = result.as_ref() { - tracing::error!("{err}"); - } - result - } - - fn on_tunnel_ready(&self) -> Result<()> { - let result = self - .0 - .on_tunnel_ready() - .map_err(|err| Error::OnTunnelReadyFailed(err.to_string())); - if let Err(err) = result.as_ref() { - tracing::error!("{err}"); - } - result - } - - fn on_add_route(&self, route: IpNetwork) -> Result<()> { - let result = self - .0 - .on_add_route(route) - .map_err(|err| Error::OnAddRouteFailed(err.to_string())); - if let Err(err) = result.as_ref() { - tracing::error!("{err}"); - } - result - } - - fn on_remove_route(&self, route: IpNetwork) -> Result<()> { - let result = self - .0 - .on_remove_route(route) - .map_err(|err| Error::OnRemoveRouteFailed(err.to_string())); - if let Err(err) = result.as_ref() { - tracing::error!("{err}"); - } - result - } - - fn on_update_resources(&self, resource_list: Vec) -> Result<()> { - let result = self - .0 - .on_update_resources(resource_list) - .map_err(|err| Error::OnUpdateResourcesFailed(err.to_string())); - if let Err(err) = result.as_ref() { - tracing::error!("{err}"); - } - result - } - - fn on_disconnect(&self, error: Option<&Error>) -> Result<()> { - if let Err(err) = self.0.on_disconnect(error) { - tracing::error!("`on_disconnect` failed: {err}"); - } - // There's nothing we can really do if `on_disconnect` fails. - Ok(()) - } - - fn on_error(&self, error: &Error) -> Result<()> { - if let Err(err) = self.0.on_error(error) { - tracing::error!("`on_error` failed: {err}"); - } - // There's nothing we really want to do if `on_error` fails. - Ok(()) - } -} - -macro_rules! fatal_error { - ($result:expr, $rt:expr, $cb:expr) => { - match $result { - Ok(res) => res, - Err(err) => { - Self::disconnect_inner($rt, $cb, Some(err)); - return; - } - } - }; -} - -impl Session -where - T: ControlSession, - U: for<'de> serde::Deserialize<'de> + std::fmt::Debug + Send + 'static, - R: for<'de> serde::Deserialize<'de> + std::fmt::Debug + Send + 'static, - V: serde::Serialize + Send + 'static, - M: From + From + Send + 'static + std::fmt::Debug, - CB: Callbacks + 'static, -{ - /// Starts a session in the background. - /// - /// This will: - /// 1. Create and start a tokio runtime - /// 2. Connect to the control plane to the portal - /// 3. Start the tunnel in the background and forward control plane messages to it. - /// - /// The generic parameter `CB` should implement all the handlers and that's how errors will be surfaced. - /// - /// On a fatal error you should call `[Session::disconnect]` and start a new one. - // TODO: token should be something like SecretString but we need to think about FFI compatibility - pub fn connect( - portal_url: impl TryInto, - token: String, - device_id: String, - callbacks: CB, - ) -> Result { - // TODO: We could use tokio::runtime::current() to get the current runtime - // which could work with swift-rust that already runs a runtime. But IDK if that will work - // in all platforms, a couple of new threads shouldn't bother none. - // Big question here however is how do we get the result? We could block here await the result and spawn a new task. - // but then platforms should know that this function is blocking. - - let callbacks = CallbackErrorFacade(callbacks); - let (tx, mut rx) = tokio::sync::mpsc::channel(1); - let this = Self { - runtime_stopper: tx.clone(), - callbacks, - _phantom: PhantomData, - }; - let runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build()?; - { - let callbacks = this.callbacks.clone(); - let default_panic_hook = std::panic::take_hook(); - std::panic::set_hook(Box::new({ - let tx = tx.clone(); - move |info| { - let tx = tx.clone(); - let err = info - .payload() - .downcast_ref::<&str>() - .map(|s| Error::Panic(s.to_string())) - .unwrap_or(Error::PanicNonStringPayload); - Self::disconnect_inner(tx, &callbacks, Some(err)); - default_panic_hook(info); - } - })); - } - - Self::connect_inner( - &runtime, - tx, - portal_url.try_into().map_err(|_| Error::UriError)?, - token, - device_id, - this.callbacks.clone(), - ); - std::thread::spawn(move || { - rx.blocking_recv(); - runtime.shutdown_background(); - }); - - Ok(this) - } - - fn connect_inner( - runtime: &Runtime, - runtime_stopper: tokio::sync::mpsc::Sender, - portal_url: Url, - token: String, - device_id: String, - callbacks: CallbackErrorFacade, - ) { - runtime.spawn(async move { - let private_key = StaticSecret::random_from_rng(OsRng); - let name_suffix: String = thread_rng().sample_iter(&Alphanumeric).take(8).map(char::from).collect(); - let external_id = sha256(device_id); - - let connect_url = fatal_error!( - get_websocket_path(portal_url, token, T::socket_path(), &Key(PublicKey::from(&private_key).to_bytes()), &external_id, &name_suffix), - runtime_stopper, - &callbacks - ); - - - // This is kinda hacky, the buffer size is 1 so that we make sure that we - // process one message at a time, blocking if a previous message haven't been processed - // to force queue ordering. - let (control_plane_sender, control_plane_receiver) = tokio::sync::mpsc::channel(1); - - let mut connection = PhoenixChannel::<_, U, R, M>::new(connect_url, move |msg, reference| { - let control_plane_sender = control_plane_sender.clone(); - async move { - tracing::trace!("Received message: {msg:?}"); - if let Err(e) = control_plane_sender.send((msg, reference)).await { - tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up."); - } - } - }); - - // Used to send internal messages - let topic = T::socket_path().to_string(); - let internal_sender = connection.sender_with_topic(topic.clone()); - fatal_error!( - T::start(private_key, control_plane_receiver, internal_sender, callbacks.0.clone()).await, - runtime_stopper, - &callbacks - ); - - tokio::spawn(async move { - let mut exponential_backoff = T::retry_strategy(); - loop { - // `connection.start` calls the callback only after connecting - tracing::debug!("Attempting connection to portal..."); - let result = connection.start(vec![topic.clone()], || exponential_backoff.reset()).await; - tracing::warn!("Disconnected from the portal"); - if let Err(e) = &result { - tracing::warn!(error = ?e, "Portal connection error"); - } - if let Some(t) = exponential_backoff.next_backoff() { - tracing::warn!("Error connecting to portal, retrying in {} seconds", t.as_secs()); - let _ = callbacks.on_error(&result.err().unwrap_or(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))); - tokio::time::sleep(t).await; - } else { - tracing::error!("Connection to portal failed, giving up"); - fatal_error!( - result.and(Err(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))), - runtime_stopper, - &callbacks - ); - } - } - - }); - - }); - } - - fn disconnect_inner( - runtime_stopper: tokio::sync::mpsc::Sender, - callbacks: &CallbackErrorFacade, - error: Option, - ) { - // 1. Close the websocket connection - // 2. Free the device handle (Linux) - // 3. Close the file descriptor (Linux/Android) - // 4. Remove the mapping - - // The way we cleanup the tasks is we drop the runtime - // this means we don't need to keep track of different tasks - // but if any of the tasks never yields this will block forever! - // So always yield and if you spawn a blocking tasks rewrite this. - // Furthermore, we will depend on Drop impls to do the list above so, - // implement them :) - // if there's no receiver the runtime is already stopped - // there's an edge case where this is called before the thread is listening for stop threads. - // but I believe in that case the channel will be in a signaled state achieving the same result - - if let Err(err) = runtime_stopper.try_send(StopRuntime) { - tracing::error!("Couldn't stop runtime: {err}"); - } - - let _ = callbacks.on_disconnect(error.as_ref()); - } - - /// Cleanup a [Session]. - /// - /// For now this just drops the runtime, which should drop all pending tasks. - /// Further cleanup should be done here. (Otherwise we can just drop [Session]). - pub fn disconnect(&mut self, error: Option) { - Self::disconnect_inner(self.runtime_stopper.clone(), &self.callbacks, error) - } -} - -fn set_ws_scheme(url: &mut Url) -> Result<()> { - let scheme = match url.scheme() { - "http" | "ws" => "ws", - "https" | "wss" => "wss", - _ => return Err(Error::UriScheme), - }; - url.set_scheme(scheme) - .expect("Developer error: the match before this should make sure we can set this"); - Ok(()) -} - -fn sha256(input: String) -> String { - let mut ctx = Context::new(&SHA256); - ctx.update(input.as_bytes()); - let digest = ctx.finish(); - - digest - .as_ref() - .iter() - .map(|b| format!("{:02x}", b)) - .collect() -} - -fn get_websocket_path( - mut url: Url, - secret: String, - mode: &str, - public_key: &Key, - external_id: &str, - name_suffix: &str, -) -> Result { - set_ws_scheme(&mut url)?; - - { - let mut paths = url.path_segments_mut().map_err(|_| Error::UriError)?; - paths.pop_if_empty(); - paths.push(mode); - paths.push("websocket"); - } - - { - let mut query_pairs = url.query_pairs_mut(); - query_pairs.clear(); - query_pairs.append_pair("token", &secret); - query_pairs.append_pair("public_key", &public_key.to_string()); - query_pairs.append_pair("external_id", external_id); - query_pairs.append_pair("name_suffix", name_suffix); - } - - Ok(url) -} diff --git a/rust/connlib/libs/gateway/Cargo.toml b/rust/connlib/libs/gateway/Cargo.toml index 37cecf21b..d55ee0a8f 100644 --- a/rust/connlib/libs/gateway/Cargo.toml +++ b/rust/connlib/libs/gateway/Cargo.toml @@ -14,6 +14,9 @@ boringtun = { workspace = true } chrono = { workspace = true } backoff = { workspace = true } webrtc = "0.8" +url = { version = "2.4.1", default-features = false } +rand = { version = "0.8", default-features = false, features = ["std"] } +tokio-tungstenite = { version = "0.19", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] } [dev-dependencies] serde_json = { version = "1.0", default-features = false, features = ["std"] } diff --git a/rust/connlib/libs/gateway/src/control.rs b/rust/connlib/libs/gateway/src/control.rs index 32ac97f20..e02d121fd 100644 --- a/rust/connlib/libs/gateway/src/control.rs +++ b/rust/connlib/libs/gateway/src/control.rs @@ -1,34 +1,26 @@ -use std::{sync::Arc, time::Duration}; - -use backoff::{ExponentialBackoff, ExponentialBackoffBuilder}; -use boringtun::x25519::StaticSecret; -use firezone_tunnel::{ConnId, ControlSignal, Tunnel}; -use libs_common::{ - control::{MessageResult, PhoenixSenderWithTopic, Reference}, - messages::{GatewayId, ResourceDescription}, - Callbacks, ControlSession, - Error::ControlProtocolError, - Result, -}; -use tokio::sync::mpsc::Receiver; -use webrtc::ice_transport::ice_candidate::RTCIceCandidate; - -use crate::messages::{AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates}; - use super::messages::{ ConnectionReady, EgressMessages, IngressMessages, InitGateway, RequestConnection, }; - +use crate::messages::{AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates}; use async_trait::async_trait; +use firezone_tunnel::{ConnId, ControlSignal, Tunnel}; +use libs_common::Error::ControlProtocolError; +use libs_common::{ + control::PhoenixSenderWithTopic, + messages::{GatewayId, ResourceDescription}, + Callbacks, Result, +}; +use std::sync::Arc; +use webrtc::ice_transport::ice_candidate::RTCIceCandidate; pub struct ControlPlane { - tunnel: Arc>, - control_signaler: ControlSignaler, + pub tunnel: Arc>, + pub control_signaler: ControlSignaler, } #[derive(Clone)] -struct ControlSignaler { - control_signal: PhoenixSenderWithTopic, +pub struct ControlSignaler { + pub control_signal: PhoenixSenderWithTopic, } #[async_trait] @@ -71,28 +63,7 @@ impl ControlSignal for ControlSignaler { impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] - async fn start( - mut self, - mut receiver: Receiver<(MessageResult, Option)>, - ) -> Result<()> { - let mut interval = tokio::time::interval(Duration::from_secs(10)); - loop { - tokio::select! { - Some((msg, _)) = receiver.recv() => { - match msg { - Ok(msg) => self.handle_message(msg).await?, - Err(_msg_reply) => todo!(), - } - }, - _ = interval.tick() => self.stats_event().await, - else => break - } - } - Ok(()) - } - - #[tracing::instrument(level = "trace", skip(self))] - async fn init(&mut self, init: InitGateway) -> Result<()> { + pub async fn init(&mut self, init: InitGateway) -> Result<()> { if let Err(e) = self.tunnel.set_interface(&init.interface).await { tracing::error!("Couldn't initialize interface: {e}"); Err(e) @@ -104,7 +75,7 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - fn connection_request(&self, connection_request: RequestConnection) { + pub fn connection_request(&self, connection_request: RequestConnection) { let tunnel = Arc::clone(&self.tunnel); let mut control_signaler = self.control_signaler.clone(); tokio::spawn(async move { @@ -141,7 +112,7 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - fn allow_access( + pub fn allow_access( &self, AllowAccess { client_id, @@ -172,7 +143,7 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - pub(super) async fn handle_message(&mut self, msg: IngressMessages) -> Result<()> { + pub async fn handle_message(&mut self, msg: IngressMessages) -> Result<()> { match msg { IngressMessages::Init(init) => self.init(init).await?, IngressMessages::RequestConnection(connection_request) => { @@ -189,40 +160,7 @@ impl ControlPlane { Ok(()) } - pub(super) async fn stats_event(&mut self) { + pub async fn stats_event(&mut self) { tracing::debug!(target: "tunnel_state", stats = ?self.tunnel.stats()); } } - -#[async_trait] -impl ControlSession for ControlPlane { - #[tracing::instrument(level = "trace", skip(private_key, callbacks))] - async fn start( - private_key: StaticSecret, - receiver: Receiver<(MessageResult, Option)>, - control_signal: PhoenixSenderWithTopic, - callbacks: CB, - ) -> Result<()> { - let control_signaler = ControlSignaler { control_signal }; - let tunnel = Arc::new(Tunnel::new(private_key, control_signaler.clone(), callbacks).await?); - - let control_plane = ControlPlane { - tunnel, - control_signaler, - }; - - tokio::spawn(async move { control_plane.start(receiver).await }); - - Ok(()) - } - - fn socket_path() -> &'static str { - "gateway" - } - - fn retry_strategy() -> ExponentialBackoff { - ExponentialBackoffBuilder::default() - .with_max_elapsed_time(None) - .build() - } -} diff --git a/rust/connlib/libs/gateway/src/lib.rs b/rust/connlib/libs/gateway/src/lib.rs index f7956e395..32b10c6a8 100644 --- a/rust/connlib/libs/gateway/src/lib.rs +++ b/rust/connlib/libs/gateway/src/lib.rs @@ -1,22 +1,245 @@ //! Main connlib library for gateway. +pub use libs_common::{get_device_id, messages::ResourceDescription, Callbacks, Error}; + +use crate::control::ControlSignaler; +use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; +use boringtun::x25519::{PublicKey, StaticSecret}; use control::ControlPlane; -use messages::EgressMessages; +use firezone_tunnel::Tunnel; +use libs_common::{ + control::PhoenixChannel, get_websocket_path, messages::Key, sha256, CallbackErrorFacade, Result, +}; use messages::IngressMessages; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use std::sync::Arc; +use std::time::Duration; +use tokio::runtime::Runtime; +use url::Url; mod control; mod messages; -/// Session type for gateway. -/// -/// For more information see libs_common docs on [Session][libs_common::Session]. -// TODO: Still working on gateway messages -pub type Session = libs_common::Session< - ControlPlane, - IngressMessages, - EgressMessages, - IngressMessages, - IngressMessages, - CB, ->; +struct StopRuntime; -pub use libs_common::{get_device_id, messages::ResourceDescription, Callbacks, Error}; +/// A session is the entry-point for connlib, maintains the runtime and the tunnel. +/// +/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect]. +pub struct Session { + runtime_stopper: tokio::sync::mpsc::Sender, + pub callbacks: CallbackErrorFacade, +} + +macro_rules! fatal_error { + ($result:expr, $rt:expr, $cb:expr) => { + match $result { + Ok(res) => res, + Err(err) => { + Self::disconnect_inner($rt, $cb, Some(err)); + return; + } + } + }; +} + +impl Session +where + CB: Callbacks + 'static, +{ + /// Starts a session in the background. + /// + /// This will: + /// 1. Create and start a tokio runtime + /// 2. Connect to the control plane to the portal + /// 3. Start the tunnel in the background and forward control plane messages to it. + /// + /// The generic parameter `CB` should implement all the handlers and that's how errors will be surfaced. + /// + /// On a fatal error you should call `[Session::disconnect]` and start a new one. + // TODO: token should be something like SecretString but we need to think about FFI compatibility + pub fn connect( + portal_url: impl TryInto, + token: String, + device_id: String, + callbacks: CB, + ) -> Result { + // TODO: We could use tokio::runtime::current() to get the current runtime + // which could work with swift-rust that already runs a runtime. But IDK if that will work + // in all platforms, a couple of new threads shouldn't bother none. + // Big question here however is how do we get the result? We could block here await the result and spawn a new task. + // but then platforms should know that this function is blocking. + + let callbacks = CallbackErrorFacade(callbacks); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let this = Self { + runtime_stopper: tx.clone(), + callbacks, + }; + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?; + { + let callbacks = this.callbacks.clone(); + let default_panic_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new({ + let tx = tx.clone(); + move |info| { + let tx = tx.clone(); + let err = info + .payload() + .downcast_ref::<&str>() + .map(|s| Error::Panic(s.to_string())) + .unwrap_or(Error::PanicNonStringPayload); + Self::disconnect_inner(tx, &callbacks, Some(err)); + default_panic_hook(info); + } + })); + } + + Self::connect_inner( + &runtime, + tx, + portal_url.try_into().map_err(|_| Error::UriError)?, + token, + device_id, + this.callbacks.clone(), + ); + std::thread::spawn(move || { + rx.blocking_recv(); + runtime.shutdown_background(); + }); + + Ok(this) + } + + fn connect_inner( + runtime: &Runtime, + runtime_stopper: tokio::sync::mpsc::Sender, + portal_url: Url, + token: String, + device_id: String, + callbacks: CallbackErrorFacade, + ) { + runtime.spawn(async move { + let private_key = StaticSecret::random_from_rng(rand::rngs::OsRng); + let name_suffix: String = thread_rng().sample_iter(&Alphanumeric).take(8).map(char::from).collect(); + let external_id = sha256(device_id); + + let connect_url = fatal_error!( + get_websocket_path(portal_url, token, "gateway", &Key(PublicKey::from(&private_key).to_bytes()), &external_id, &name_suffix), + runtime_stopper, + &callbacks + ); + + + // This is kinda hacky, the buffer size is 1 so that we make sure that we + // process one message at a time, blocking if a previous message haven't been processed + // to force queue ordering. + let (control_plane_sender, mut control_plane_receiver) = tokio::sync::mpsc::channel(1); + + let mut connection = PhoenixChannel::<_, IngressMessages, IngressMessages, IngressMessages>::new(connect_url, move |msg, reference| { + let control_plane_sender = control_plane_sender.clone(); + async move { + tracing::trace!("Received message: {msg:?}"); + if let Err(e) = control_plane_sender.send((msg, reference)).await { + tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up."); + } + } + }); + + // Used to send internal messages + let control_signaler = ControlSignaler { control_signal: connection.sender_with_topic("gateway".to_owned()) }; + let tunnel = fatal_error!( + Tunnel::new(private_key, control_signaler.clone(), callbacks.clone()).await, + runtime_stopper, + &callbacks + ); + + let mut control_plane = ControlPlane { + tunnel: Arc::new(tunnel), + control_signaler, + }; + + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(10)); + loop { + tokio::select! { + Some((msg, _)) = control_plane_receiver.recv() => { + match msg { + Ok(msg) => control_plane.handle_message(msg).await?, + Err(_msg_reply) => todo!(), + } + }, + _ = interval.tick() => control_plane.stats_event().await, + else => break + } + } + + Result::Ok(()) + }); + + tokio::spawn(async move { + let mut exponential_backoff = ExponentialBackoffBuilder::default() + .with_max_elapsed_time(None) + .build(); + loop { + // `connection.start` calls the callback only after connecting + tracing::debug!("Attempting connection to portal..."); + let result = connection.start(vec!["gateway".to_owned()], || exponential_backoff.reset()).await; + tracing::warn!("Disconnected from the portal"); + if let Err(e) = &result { + tracing::warn!(error = ?e, "Portal connection error"); + } + if let Some(t) = exponential_backoff.next_backoff() { + tracing::warn!("Error connecting to portal, retrying in {} seconds", t.as_secs()); + let _ = callbacks.on_error(&result.err().unwrap_or(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))); + tokio::time::sleep(t).await; + } else { + tracing::error!("Connection to portal failed, giving up"); + fatal_error!( + result.and(Err(Error::PortalConnectionError(tokio_tungstenite::tungstenite::Error::ConnectionClosed))), + runtime_stopper, + &callbacks + ); + } + } + + }); + + }); + } + + fn disconnect_inner( + runtime_stopper: tokio::sync::mpsc::Sender, + callbacks: &CallbackErrorFacade, + error: Option, + ) { + // 1. Close the websocket connection + // 2. Free the device handle (Linux) + // 3. Close the file descriptor (Linux/Android) + // 4. Remove the mapping + + // The way we cleanup the tasks is we drop the runtime + // this means we don't need to keep track of different tasks + // but if any of the tasks never yields this will block forever! + // So always yield and if you spawn a blocking tasks rewrite this. + // Furthermore, we will depend on Drop impls to do the list above so, + // implement them :) + // if there's no receiver the runtime is already stopped + // there's an edge case where this is called before the thread is listening for stop threads. + // but I believe in that case the channel will be in a signaled state achieving the same result + + if let Err(err) = runtime_stopper.try_send(StopRuntime) { + tracing::error!("Couldn't stop runtime: {err}"); + } + + let _ = callbacks.on_disconnect(error.as_ref()); + } + + /// Cleanup a [Session]. + /// + /// For now this just drops the runtime, which should drop all pending tasks. + /// Further cleanup should be done here. (Otherwise we can just drop [Session]). + pub fn disconnect(&mut self, error: Option) { + Self::disconnect_inner(self.runtime_stopper.clone(), &self.callbacks, error) + } +} diff --git a/rust/connlib/libs/tunnel/src/lib.rs b/rust/connlib/libs/tunnel/src/lib.rs index 03d30cd9b..777f0ba13 100644 --- a/rust/connlib/libs/tunnel/src/lib.rs +++ b/rust/connlib/libs/tunnel/src/lib.rs @@ -10,7 +10,7 @@ use bytes::Bytes; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; -use libs_common::{messages::Key, Callbacks, Error, DNS_SENTINEL}; +use libs_common::{messages::Key, CallbackErrorFacade, Callbacks, Error, DNS_SENTINEL}; use serde::{Deserialize, Serialize}; use async_trait::async_trait; @@ -35,7 +35,7 @@ use libs_common::{ messages::{ ClientId, GatewayId, Interface as InterfaceConfig, ResourceDescription, ResourceId, }, - CallbackErrorFacade, Result, + Result, }; use device_channel::{create_iface, DeviceIo, IfaceConfig};