diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 9310343a7..2bde087ad 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1827,6 +1827,7 @@ dependencies = [ "swift-bridge", "thiserror", "tokio", + "tokio-stream", "tokio-tungstenite", "tracing", "url", @@ -3432,6 +3433,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-tungstenite" version = "0.19.0" diff --git a/rust/connlib/libs/client/src/control.rs b/rust/connlib/libs/client/src/control.rs index 857322993..93179c039 100644 --- a/rust/connlib/libs/client/src/control.rs +++ b/rust/connlib/libs/client/src/control.rs @@ -4,21 +4,22 @@ use crate::messages::{Connect, ConnectionDetails, EgressMessages, InitClient, Me use backoff::{ExponentialBackoff, ExponentialBackoffBuilder}; use boringtun::x25519::StaticSecret; use libs_common::{ - control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic}, + control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic, Reference}, messages::{Id, ResourceDescription}, Callbacks, ControlSession, Error, Result, }; use async_trait::async_trait; use firezone_tunnel::{ControlSignal, Request, Tunnel}; -use tokio::sync::mpsc::Receiver; +use tokio::sync::{mpsc::Receiver, Mutex}; #[async_trait] impl ControlSignal for ControlSignaler { async fn signal_connection_to( &self, resource: &ResourceDescription, - connected_gateway_ids: Vec, + connected_gateway_ids: &[Id], + reference: usize, ) -> Result<()> { self.control_signal // It's easier if self is not mut @@ -26,11 +27,9 @@ impl ControlSignal for ControlSignaler { .send_with_ref( EgressMessages::PrepareConnection { resource_id: resource.id(), - connected_gateway_ids, + connected_gateway_ids: connected_gateway_ids.to_vec(), }, - // The resource id functions as the connection id since we can only have one connection - // outgoing for each resource. - resource.id(), + reference, ) .await?; Ok(()) @@ -41,6 +40,7 @@ impl ControlSignal for ControlSignaler { pub struct ControlPlane { tunnel: Arc>, control_signaler: ControlSignaler, + tunnel_init: Mutex, } #[derive(Clone)] @@ -50,14 +50,17 @@ struct ControlSignaler { impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] - async fn start(mut self, mut receiver: Receiver>) -> Result<()> { + async fn start( + mut self, + mut receiver: Receiver<(MessageResult, Option)>, + ) -> Result<()> { let mut interval = tokio::time::interval(Duration::from_secs(10)); loop { tokio::select! { - Some(msg) = receiver.recv() => { + Some((msg, reference)) = receiver.recv() => { match msg { - Ok(msg) => self.handle_message(msg).await?, - Err(msg_reply) => self.handle_error(msg_reply).await, + Ok(msg) => self.handle_message(msg, reference).await?, + Err(err) => self.handle_error(err, reference).await, } }, _ = interval.tick() => self.stats_event().await, @@ -75,16 +78,25 @@ impl ControlPlane { resources, }: InitClient, ) -> Result<()> { - if let Err(e) = self.tunnel.set_interface(&interface).await { - tracing::error!(error = ?e, "Error initializing interface"); - Err(e) - } else { - for resource_description in resources { - self.add_resource(resource_description).await; + { + let mut init = self.tunnel_init.lock().await; + if !*init { + if let Err(e) = self.tunnel.set_interface(&interface).await { + tracing::error!(error = ?e, "Error initializing interface"); + return Err(e); + } else { + *init = true; + tracing::info!("Firezoned Started!"); + } + } else { + tracing::info!("Firezoned reinitializated"); } - tracing::info!("Firezoned Started!"); - Ok(()) } + + for resource_description in resources { + self.add_resource(resource_description).await; + } + Ok(()) } #[tracing::instrument(level = "trace", skip(self))] @@ -137,12 +149,13 @@ impl ControlPlane { relays, .. }: ConnectionDetails, + reference: Option, ) { let tunnel = Arc::clone(&self.tunnel); let mut control_signaler = self.control_signaler.clone(); tokio::spawn(async move { let err = match tunnel - .request_connection(resource_id, gateway_id, relays) + .request_connection(resource_id, gateway_id, relays, reference) .await { Ok(Request::NewConnection(connection_request)) => { @@ -185,11 +198,15 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - pub(super) async fn handle_message(&mut self, msg: Messages) -> Result<()> { + pub(super) async fn handle_message( + &mut self, + msg: Messages, + reference: Option, + ) -> Result<()> { match msg { Messages::Init(init) => self.init(init).await?, Messages::ConnectionDetails(connection_details) => { - self.connection_details(connection_details) + self.connection_details(connection_details, reference) } Messages::Connect(connect) => self.connect(connect).await, Messages::ResourceAdded(resource) => self.add_resource(resource).await, @@ -200,9 +217,13 @@ impl ControlPlane { } #[tracing::instrument(level = "trace", skip(self))] - pub(super) async fn handle_error(&mut self, reply_error: ErrorReply) { + pub(super) async fn handle_error( + &mut self, + reply_error: ErrorReply, + reference: Option, + ) { if matches!(reply_error.error, ErrorInfo::Offline) { - match reply_error.reference { + match reference { Some(reference) => { let Ok(id) = reference.parse() else { tracing::error!( @@ -240,7 +261,7 @@ impl ControlSession for ControlPlane #[tracing::instrument(level = "trace", skip(private_key, callbacks))] async fn start( private_key: StaticSecret, - receiver: Receiver>, + receiver: Receiver<(MessageResult, Option)>, control_signal: PhoenixSenderWithTopic, callbacks: CB, ) -> Result<()> { @@ -250,6 +271,7 @@ impl ControlSession for ControlPlane let control_plane = ControlPlane { tunnel, control_signaler, + tunnel_init: Mutex::new(false), }; tokio::spawn(async move { control_plane.start(receiver).await }); diff --git a/rust/connlib/libs/common/Cargo.toml b/rust/connlib/libs/common/Cargo.toml index 54f564ed6..f147f71ee 100644 --- a/rust/connlib/libs/common/Cargo.toml +++ b/rust/connlib/libs/common/Cargo.toml @@ -31,6 +31,7 @@ rand = { version = "0.8", default-features = false, features = ["std"] } chrono = { workspace = true } parking_lot = "0.12" ring = "0.16" +tokio-stream = { version = "0.1", features = ["time"] } # Needed for Android logging until tracing is working log = "0.4" diff --git a/rust/connlib/libs/common/src/control.rs b/rust/connlib/libs/common/src/control.rs index 4776444ec..a54d3fd94 100644 --- a/rust/connlib/libs/common/src/control.rs +++ b/rust/connlib/libs/common/src/control.rs @@ -11,9 +11,10 @@ use futures::{ channel::mpsc::{channel, Receiver, Sender}, TryStreamExt, }; -use futures_util::{Future, SinkExt, StreamExt}; +use futures_util::{Future, SinkExt, StreamExt, TryFutureExt}; use rand_core::{OsRng, RngCore}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use tokio_stream::StreamExt as _; use tokio_tungstenite::{ connect_async, tungstenite::{self, handshake::client::Request}, @@ -24,6 +25,10 @@ use url::Url; use crate::{get_user_agent, Error, Result}; const CHANNEL_SIZE: usize = 1_000; +const HEARTBEAT: Duration = Duration::from_secs(30); +const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(35); + +pub type Reference = String; /// Main struct to interact with the control-protocol channel. /// @@ -79,7 +84,7 @@ where I: DeserializeOwned, R: DeserializeOwned, M: From + From, - F: Fn(MessageResult) -> Fut, + F: Fn(MessageResult, Option) -> Fut, Fut: Future + Send + 'static, { /// Starts the tunnel with the parameters given in [Self::new]. @@ -110,7 +115,10 @@ where handler, receiver, .. } = self; - let process_messages = read.try_for_each(|message| async { + let process_messages = tokio_stream::StreamExt::map(read.timeout(HEARTBEAT_TIMEOUT), |m| { + m.map_err(Error::from)?.map_err(Error::from) + }) + .try_for_each(|message| async { Self::message_process(handler, message).await; Ok(()) }); @@ -141,13 +149,20 @@ where // Furthermore can this also happen if write errors out? *that* I'd assume is possible... // What option is left? write a new future to forward items. // For now we should never assume that an item arrived the portal because we sent it! - let send_messages = receiver.map(Ok).forward(write); + let send_messages = futures::StreamExt::map(receiver, Ok) + .forward(write) + .map_err(Error::from); let phoenix_heartbeat = tokio::spawn(async move { - let mut timer = tokio::time::interval(Duration::from_secs(30)); + let mut timer = tokio::time::interval(HEARTBEAT); loop { timer.tick().await; - let Ok(_) = sender.send("phoenix", EgressControlMessage::Heartbeat(Empty {})).await else { break }; + let Ok(_) = sender + .send("phoenix", EgressControlMessage::Heartbeat(Empty {})) + .await + else { + break; + }; } }); @@ -174,30 +189,28 @@ where match message.into_text() { Ok(m_str) => match serde_json::from_str::>(&m_str) { Ok(m) => match m.payload { - Payload::Message(m) => handler(Ok(m.into())).await, + Payload::Message(payload) => handler(Ok(payload.into()), m.reference).await, Payload::Reply(status) => match status { ReplyMessage::PhxReply(phx_reply) => match phx_reply { // TODO: Here we should pass error info to a subscriber PhxReply::Error(info) => { tracing::warn!("Portal error: {info:?}"); - handler(Err(ErrorReply { - error: info, - reference: m.reference, - })) - .await + handler(Err(ErrorReply { error: info }), m.reference).await } PhxReply::Ok(reply) => match reply { OkReply::NoMessage(Empty {}) => { - tracing::trace!("Phoenix status message") + tracing::trace!(target: "phoenix_status", "Phoenix status message") + } + OkReply::Message(payload) => { + handler(Ok(payload.into()), m.reference).await } - OkReply::Message(m) => handler(Ok(m.into())).await, }, }, ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"), }, }, Err(e) => { - tracing::error!("Error deserializing message {m_str}: {e:?}"); + tracing::error!(message = "Error deserializing message", message_string = m_str, error = ?e); } }, _ => tracing::error!("Received message that is not text"), @@ -254,8 +267,6 @@ pub type MessageResult = std::result::Result; pub struct ErrorReply { /// Information of the error pub error: ErrorInfo, - /// Reference to the message that caused the error - pub reference: Option, } #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] diff --git a/rust/connlib/libs/common/src/error.rs b/rust/connlib/libs/common/src/error.rs index ce49eb5dc..34a39ac1f 100644 --- a/rust/connlib/libs/common/src/error.rs +++ b/rust/connlib/libs/common/src/error.rs @@ -21,6 +21,9 @@ pub enum ConnlibError { /// Request error for websocket connection. #[error("Error forming request: {0}")] RequestError(#[from] tokio_tungstenite::tungstenite::http::Error), + /// Websocket heartbeat timedout + #[error("Websocket heartbeat timedout")] + WebsocketTimeout(#[from] tokio_stream::Elapsed), /// Error during websocket connection. #[error("Portal connection error: {0}")] PortalConnectionError(#[from] tokio_tungstenite::tungstenite::error::Error), @@ -99,6 +102,12 @@ pub enum ConnlibError { /// A panic occurred with a non-string payload. #[error("Panicked with a non-string payload")] PanicNonStringPayload, + /// Received connection details that might be stale + #[error("Unexpected connection details")] + UnexpectedConnectionDetails, + /// Invalid phoenix channel reference + #[error("Invalid phoenix channel reply reference")] + InvalidReference, } #[cfg(target_os = "linux")] diff --git a/rust/connlib/libs/common/src/session.rs b/rust/connlib/libs/common/src/session.rs index dc7d54447..d76b5856b 100644 --- a/rust/connlib/libs/common/src/session.rs +++ b/rust/connlib/libs/common/src/session.rs @@ -16,7 +16,7 @@ use tokio::{runtime::Runtime, sync::mpsc::Receiver}; use url::Url; use crate::{ - control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic}, + control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic, Reference}, messages::{Key, ResourceDescription}, Error, Result, }; @@ -33,7 +33,7 @@ pub trait ControlSession { /// Start control-plane with the given private-key in the background. async fn start( private_key: StaticSecret, - receiver: Receiver>, + receiver: Receiver<(MessageResult, Option)>, control_signal: PhoenixSenderWithTopic, callbacks: CB, ) -> Result<()>; @@ -292,11 +292,11 @@ where // to force queue ordering. let (control_plane_sender, control_plane_receiver) = tokio::sync::mpsc::channel(1); - let mut connection = PhoenixChannel::<_, U, R, M>::new(connect_url, move |msg| { + let mut connection = PhoenixChannel::<_, U, R, M>::new(connect_url, move |msg, reference| { let control_plane_sender = control_plane_sender.clone(); async move { tracing::trace!("Received message: {msg:?}"); - if let Err(e) = control_plane_sender.send(msg).await { + if let Err(e) = control_plane_sender.send((msg, reference)).await { tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up."); } } @@ -318,8 +318,8 @@ where tracing::debug!("Attempting connection to portal..."); let result = connection.start(vec![topic.clone()], || exponential_backoff.reset()).await; tracing::warn!("Disconnected from the portal"); - if let Err(err) = &result { - tracing::warn!("Portal connection error: {err}"); + if let Err(e) = &result { + tracing::warn!(error = ?e, "Portal connection error"); } if let Some(t) = exponential_backoff.next_backoff() { tracing::warn!("Error connecting to portal, retrying in {} seconds", t.as_secs()); diff --git a/rust/connlib/libs/gateway/src/control.rs b/rust/connlib/libs/gateway/src/control.rs index 05bd1f9a5..e26f9b13d 100644 --- a/rust/connlib/libs/gateway/src/control.rs +++ b/rust/connlib/libs/gateway/src/control.rs @@ -4,7 +4,7 @@ use backoff::{ExponentialBackoff, ExponentialBackoffBuilder}; use boringtun::x25519::StaticSecret; use firezone_tunnel::{ControlSignal, Tunnel}; use libs_common::{ - control::{MessageResult, PhoenixSenderWithTopic}, + control::{MessageResult, PhoenixSenderWithTopic, Reference}, messages::{Id, ResourceDescription}, Callbacks, ControlSession, Result, }; @@ -33,7 +33,8 @@ impl ControlSignal for ControlSignaler { async fn signal_connection_to( &self, resource: &ResourceDescription, - _connected_gateway_ids: Vec, + _connected_gateway_ids: &[Id], + _: usize, ) -> Result<()> { tracing::warn!("A message to network resource: {resource:?} was discarded, gateways aren't meant to be used as clients."); Ok(()) @@ -42,11 +43,14 @@ impl ControlSignal for ControlSignaler { impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] - async fn start(mut self, mut receiver: Receiver>) -> Result<()> { + async fn start( + mut self, + mut receiver: Receiver<(MessageResult, Option)>, + ) -> Result<()> { let mut interval = tokio::time::interval(Duration::from_secs(10)); loop { tokio::select! { - Some(msg) = receiver.recv() => { + Some((msg, _)) = receiver.recv() => { match msg { Ok(msg) => self.handle_message(msg).await?, Err(_msg_reply) => todo!(), @@ -144,7 +148,7 @@ impl ControlSession for ControlPla #[tracing::instrument(level = "trace", skip(private_key, callbacks))] async fn start( private_key: StaticSecret, - receiver: Receiver>, + receiver: Receiver<(MessageResult, Option)>, control_signal: PhoenixSenderWithTopic, callbacks: CB, ) -> Result<()> { diff --git a/rust/connlib/libs/tunnel/src/control_protocol.rs b/rust/connlib/libs/tunnel/src/control_protocol.rs index dd03003fe..a4c5e9745 100644 --- a/rust/connlib/libs/tunnel/src/control_protocol.rs +++ b/rust/connlib/libs/tunnel/src/control_protocol.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use tracing::instrument; use libs_common::{ + control::Reference, messages::{Id, Key, Relay, RequestConnection, ResourceDescription, ReuseConnection}, Callbacks, Error, Result, }; @@ -255,16 +256,39 @@ where resource_id: Id, gateway_id: Id, relays: Vec, + reference: Option, ) -> Result { - self.resources_gateways - .lock() - .insert(resource_id, gateway_id); + tracing::trace!("Received gateways and relays for resource, requesting connection"); let resource_description = self .resources .read() .get_by_id(&resource_id) .ok_or(Error::UnknownResource)? .clone(); + + let reference: usize = reference + .ok_or(Error::InvalidReference)? + .parse() + .map_err(|_| Error::InvalidReference)?; + { + let mut awaiting_connections = self.awaiting_connection.lock(); + let Some(awaiting_connection) = awaiting_connections.get_mut(&resource_id) else { + return Err(Error::UnexpectedConnectionDetails); + }; + awaiting_connection.response_recieved = true; + if awaiting_connection.total_attemps != reference + || resource_description + .ips() + .iter() + .any(|&ip| self.peers_by_ip.read().exact_match(ip).is_some()) + { + return Err(Error::UnexpectedConnectionDetails); + } + } + + self.resources_gateways + .lock() + .insert(resource_id, gateway_id); { let mut gateway_awaiting_connection = self.gateway_awaiting_connection.lock(); if let Some(g) = gateway_awaiting_connection.get_mut(&gateway_id) { @@ -278,23 +302,38 @@ where } } { - let mut peers_by_ip = self.peers_by_ip.write(); - let peer = peers_by_ip - .iter() - .find_map(|(_, p)| (p.conn_id == gateway_id).then_some(p)) - .cloned(); - if let Some(peer) = peer { - for ip in resource_description.ips() { - peer.add_allowed_ip(ip); - peers_by_ip.insert(ip, Arc::clone(&peer)); + let found = { + let mut peers_by_ip = self.peers_by_ip.write(); + let peer = peers_by_ip + .iter() + .find_map(|(_, p)| (p.conn_id == gateway_id).then_some(p)) + .cloned(); + if let Some(peer) = peer { + for ip in resource_description.ips() { + peer.add_allowed_ip(ip); + peers_by_ip.insert(ip, Arc::clone(&peer)); + } + true + } else { + false } + }; + + if found { + self.awaiting_connection.lock().remove(&resource_id); return Ok(Request::ReuseConnection(ReuseConnection { resource_id, gateway_id, })); } } - let peer_connection = self.initialize_peer_request(relays).await?; + let peer_connection = { + let peer_connection = Arc::new(self.initialize_peer_request(relays).await?); + let mut peer_connections = self.peer_connections.lock(); + peer_connections.insert(gateway_id, Arc::clone(&peer_connection)); + peer_connection + }; + self.set_connection_state_update_initiator(&peer_connection, gateway_id, resource_id); let data_channel = peer_connection.create_data_channel("data", None).await?; @@ -360,10 +399,6 @@ where .await .expect("Developer error: set_local_description was just called above"); - self.peer_connections - .lock() - .insert(gateway_id, peer_connection); - Ok(Request::NewConnection(RequestConnection { resource_id, gateway_id, diff --git a/rust/connlib/libs/tunnel/src/lib.rs b/rust/connlib/libs/tunnel/src/lib.rs index 742663780..392b623f4 100644 --- a/rust/connlib/libs/tunnel/src/lib.rs +++ b/rust/connlib/libs/tunnel/src/lib.rs @@ -29,12 +29,7 @@ use webrtc::{ peer_connection::RTCPeerConnection, }; -use std::{ - collections::{HashMap, HashSet}, - net::IpAddr, - sync::Arc, - time::Duration, -}; +use std::{collections::HashMap, net::IpAddr, sync::Arc, time::Duration}; use libs_common::{ messages::{Id, Interface as InterfaceConfig, ResourceDescription}, @@ -93,6 +88,8 @@ const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1); const HANDSHAKE_RATE_LIMIT: u64 = 100; const MAX_UDP_SIZE: usize = (1 << 16) - 1; +const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2); + /// Represent's the tunnel actual peer's config /// Obtained from libs_common's Peer #[derive(Clone)] @@ -125,10 +122,17 @@ pub trait ControlSignal { async fn signal_connection_to( &self, resource: &ResourceDescription, - connected_gateway_ids: Vec, + connected_gateway_ids: &[Id], + reference: usize, ) -> Result<()>; } +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +struct AwaitingConnectionDetails { + pub total_attemps: usize, + pub response_recieved: bool, +} + // TODO: We should use newtypes for each kind of Id /// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets /// to communicate between peers. @@ -143,7 +147,7 @@ pub struct Tunnel { public_key: PublicKey, peers_by_ip: RwLock>>, peer_connections: Mutex>>, - awaiting_connection: Mutex>, + awaiting_connection: Mutex>, gateway_awaiting_connection: Mutex>>, resources_gateways: Mutex>, webrtc_api: API, @@ -160,12 +164,13 @@ pub struct TunnelStats { public_key: String, peers_by_ip: HashMap, peer_connections: Vec, - awaiting_connection: HashSet, - gateway_awaiting_connection: HashMap>, resource_gateways: HashMap, dns_resources: HashMap, network_resources: HashMap, gateway_public_keys: HashMap, + + awaiting_connection: HashMap, + gateway_awaiting_connection: HashMap>, } impl Tunnel @@ -654,13 +659,13 @@ where // and we are finding another packet to the same address (otherwise we would just use peer_connections here) let mut awaiting_connection = dev.awaiting_connection.lock(); let id = resource.id(); - if !awaiting_connection.contains(&id) { + if awaiting_connection.get(&id).is_none() { tracing::trace!( message = "Found new intent to send packets to resource", resource_ip = %dst_addr ); - awaiting_connection.insert(id); + awaiting_connection.insert(id, Default::default()); let dev = Arc::clone(&dev); let mut connected_gateway_ids: Vec<_> = dev @@ -676,15 +681,38 @@ where message = "Currently connected gateways", gateways = ?connected_gateway_ids ); tokio::spawn(async move { - if let Err(e) = dev - .control_signaler - .signal_connection_to(&resource, connected_gateway_ids) - .await - { - // Not a deadlock because this is a different task - dev.awaiting_connection.lock().remove(&id); - tracing::error!(message = "couldn't start protocol for new connection to resource", error = ?e); - let _ = dev.callbacks.on_error(&e); + let mut interval = + tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY); + loop { + interval.tick().await; + let reference = { + let mut awaiting_connections = + dev.awaiting_connection.lock(); + let Some(awaiting_connection) = + awaiting_connections.get_mut(&resource.id()) + else { + break; + }; + if awaiting_connection.response_recieved { + break; + } + awaiting_connection.total_attemps += 1; + awaiting_connection.total_attemps + }; + if let Err(e) = dev + .control_signaler + .signal_connection_to( + &resource, + &connected_gateway_ids, + reference, + ) + .await + { + // Not a deadlock because this is a different task + dev.awaiting_connection.lock().remove(&id); + tracing::error!(message = "couldn't start protocol for new connection to resource", error = ?e); + let _ = dev.callbacks.on_error(&e); + } } }); }