mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
Connlib/fix stability issues (#1974)
When we lost networks(or change them), the phoenix channel didn't detect that the connection was lost, since the underlying websocket doesn't return an error if it's not closed gracefully. So we expect the heartbeat at some point to consider the connection down. Furthermore, while the connection is down sending the connection intents to the portal fails silently, so now we re-try the message until we get a response and built some race-condition protections in case we get multiple or stale responses.
This commit is contained in:
12
rust/Cargo.lock
generated
12
rust/Cargo.lock
generated
@@ -1827,6 +1827,7 @@ dependencies = [
|
||||
"swift-bridge",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-tungstenite",
|
||||
"tracing",
|
||||
"url",
|
||||
@@ -3432,6 +3433,17 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.19.0"
|
||||
|
||||
@@ -4,21 +4,22 @@ use crate::messages::{Connect, ConnectionDetails, EgressMessages, InitClient, Me
|
||||
use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
|
||||
use boringtun::x25519::StaticSecret;
|
||||
use libs_common::{
|
||||
control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic},
|
||||
control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic, Reference},
|
||||
messages::{Id, ResourceDescription},
|
||||
Callbacks, ControlSession, Error, Result,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use firezone_tunnel::{ControlSignal, Request, Tunnel};
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
use tokio::sync::{mpsc::Receiver, Mutex};
|
||||
|
||||
#[async_trait]
|
||||
impl ControlSignal for ControlSignaler {
|
||||
async fn signal_connection_to(
|
||||
&self,
|
||||
resource: &ResourceDescription,
|
||||
connected_gateway_ids: Vec<Id>,
|
||||
connected_gateway_ids: &[Id],
|
||||
reference: usize,
|
||||
) -> Result<()> {
|
||||
self.control_signal
|
||||
// It's easier if self is not mut
|
||||
@@ -26,11 +27,9 @@ impl ControlSignal for ControlSignaler {
|
||||
.send_with_ref(
|
||||
EgressMessages::PrepareConnection {
|
||||
resource_id: resource.id(),
|
||||
connected_gateway_ids,
|
||||
connected_gateway_ids: connected_gateway_ids.to_vec(),
|
||||
},
|
||||
// The resource id functions as the connection id since we can only have one connection
|
||||
// outgoing for each resource.
|
||||
resource.id(),
|
||||
reference,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
@@ -41,6 +40,7 @@ impl ControlSignal for ControlSignaler {
|
||||
pub struct ControlPlane<CB: Callbacks> {
|
||||
tunnel: Arc<Tunnel<ControlSignaler, CB>>,
|
||||
control_signaler: ControlSignaler,
|
||||
tunnel_init: Mutex<bool>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -50,14 +50,17 @@ struct ControlSignaler {
|
||||
|
||||
impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
async fn start(mut self, mut receiver: Receiver<MessageResult<Messages>>) -> Result<()> {
|
||||
async fn start(
|
||||
mut self,
|
||||
mut receiver: Receiver<(MessageResult<Messages>, Option<Reference>)>,
|
||||
) -> Result<()> {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(10));
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(msg) = receiver.recv() => {
|
||||
Some((msg, reference)) = receiver.recv() => {
|
||||
match msg {
|
||||
Ok(msg) => self.handle_message(msg).await?,
|
||||
Err(msg_reply) => self.handle_error(msg_reply).await,
|
||||
Ok(msg) => self.handle_message(msg, reference).await?,
|
||||
Err(err) => self.handle_error(err, reference).await,
|
||||
}
|
||||
},
|
||||
_ = interval.tick() => self.stats_event().await,
|
||||
@@ -75,16 +78,25 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
resources,
|
||||
}: InitClient,
|
||||
) -> Result<()> {
|
||||
if let Err(e) = self.tunnel.set_interface(&interface).await {
|
||||
tracing::error!(error = ?e, "Error initializing interface");
|
||||
Err(e)
|
||||
} else {
|
||||
for resource_description in resources {
|
||||
self.add_resource(resource_description).await;
|
||||
{
|
||||
let mut init = self.tunnel_init.lock().await;
|
||||
if !*init {
|
||||
if let Err(e) = self.tunnel.set_interface(&interface).await {
|
||||
tracing::error!(error = ?e, "Error initializing interface");
|
||||
return Err(e);
|
||||
} else {
|
||||
*init = true;
|
||||
tracing::info!("Firezoned Started!");
|
||||
}
|
||||
} else {
|
||||
tracing::info!("Firezoned reinitializated");
|
||||
}
|
||||
tracing::info!("Firezoned Started!");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
for resource_description in resources {
|
||||
self.add_resource(resource_description).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
@@ -137,12 +149,13 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
relays,
|
||||
..
|
||||
}: ConnectionDetails,
|
||||
reference: Option<Reference>,
|
||||
) {
|
||||
let tunnel = Arc::clone(&self.tunnel);
|
||||
let mut control_signaler = self.control_signaler.clone();
|
||||
tokio::spawn(async move {
|
||||
let err = match tunnel
|
||||
.request_connection(resource_id, gateway_id, relays)
|
||||
.request_connection(resource_id, gateway_id, relays, reference)
|
||||
.await
|
||||
{
|
||||
Ok(Request::NewConnection(connection_request)) => {
|
||||
@@ -185,11 +198,15 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub(super) async fn handle_message(&mut self, msg: Messages) -> Result<()> {
|
||||
pub(super) async fn handle_message(
|
||||
&mut self,
|
||||
msg: Messages,
|
||||
reference: Option<Reference>,
|
||||
) -> Result<()> {
|
||||
match msg {
|
||||
Messages::Init(init) => self.init(init).await?,
|
||||
Messages::ConnectionDetails(connection_details) => {
|
||||
self.connection_details(connection_details)
|
||||
self.connection_details(connection_details, reference)
|
||||
}
|
||||
Messages::Connect(connect) => self.connect(connect).await,
|
||||
Messages::ResourceAdded(resource) => self.add_resource(resource).await,
|
||||
@@ -200,9 +217,13 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub(super) async fn handle_error(&mut self, reply_error: ErrorReply) {
|
||||
pub(super) async fn handle_error(
|
||||
&mut self,
|
||||
reply_error: ErrorReply,
|
||||
reference: Option<Reference>,
|
||||
) {
|
||||
if matches!(reply_error.error, ErrorInfo::Offline) {
|
||||
match reply_error.reference {
|
||||
match reference {
|
||||
Some(reference) => {
|
||||
let Ok(id) = reference.parse() else {
|
||||
tracing::error!(
|
||||
@@ -240,7 +261,7 @@ impl<CB: Callbacks + 'static> ControlSession<Messages, CB> for ControlPlane<CB>
|
||||
#[tracing::instrument(level = "trace", skip(private_key, callbacks))]
|
||||
async fn start(
|
||||
private_key: StaticSecret,
|
||||
receiver: Receiver<MessageResult<Messages>>,
|
||||
receiver: Receiver<(MessageResult<Messages>, Option<Reference>)>,
|
||||
control_signal: PhoenixSenderWithTopic,
|
||||
callbacks: CB,
|
||||
) -> Result<()> {
|
||||
@@ -250,6 +271,7 @@ impl<CB: Callbacks + 'static> ControlSession<Messages, CB> for ControlPlane<CB>
|
||||
let control_plane = ControlPlane {
|
||||
tunnel,
|
||||
control_signaler,
|
||||
tunnel_init: Mutex::new(false),
|
||||
};
|
||||
|
||||
tokio::spawn(async move { control_plane.start(receiver).await });
|
||||
|
||||
@@ -31,6 +31,7 @@ rand = { version = "0.8", default-features = false, features = ["std"] }
|
||||
chrono = { workspace = true }
|
||||
parking_lot = "0.12"
|
||||
ring = "0.16"
|
||||
tokio-stream = { version = "0.1", features = ["time"] }
|
||||
|
||||
# Needed for Android logging until tracing is working
|
||||
log = "0.4"
|
||||
|
||||
@@ -11,9 +11,10 @@ use futures::{
|
||||
channel::mpsc::{channel, Receiver, Sender},
|
||||
TryStreamExt,
|
||||
};
|
||||
use futures_util::{Future, SinkExt, StreamExt};
|
||||
use futures_util::{Future, SinkExt, StreamExt, TryFutureExt};
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio_stream::StreamExt as _;
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::{self, handshake::client::Request},
|
||||
@@ -24,6 +25,10 @@ use url::Url;
|
||||
use crate::{get_user_agent, Error, Result};
|
||||
|
||||
const CHANNEL_SIZE: usize = 1_000;
|
||||
const HEARTBEAT: Duration = Duration::from_secs(30);
|
||||
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(35);
|
||||
|
||||
pub type Reference = String;
|
||||
|
||||
/// Main struct to interact with the control-protocol channel.
|
||||
///
|
||||
@@ -79,7 +84,7 @@ where
|
||||
I: DeserializeOwned,
|
||||
R: DeserializeOwned,
|
||||
M: From<I> + From<R>,
|
||||
F: Fn(MessageResult<M>) -> Fut,
|
||||
F: Fn(MessageResult<M>, Option<Reference>) -> Fut,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
/// Starts the tunnel with the parameters given in [Self::new].
|
||||
@@ -110,7 +115,10 @@ where
|
||||
handler, receiver, ..
|
||||
} = self;
|
||||
|
||||
let process_messages = read.try_for_each(|message| async {
|
||||
let process_messages = tokio_stream::StreamExt::map(read.timeout(HEARTBEAT_TIMEOUT), |m| {
|
||||
m.map_err(Error::from)?.map_err(Error::from)
|
||||
})
|
||||
.try_for_each(|message| async {
|
||||
Self::message_process(handler, message).await;
|
||||
Ok(())
|
||||
});
|
||||
@@ -141,13 +149,20 @@ where
|
||||
// Furthermore can this also happen if write errors out? *that* I'd assume is possible...
|
||||
// What option is left? write a new future to forward items.
|
||||
// For now we should never assume that an item arrived the portal because we sent it!
|
||||
let send_messages = receiver.map(Ok).forward(write);
|
||||
let send_messages = futures::StreamExt::map(receiver, Ok)
|
||||
.forward(write)
|
||||
.map_err(Error::from);
|
||||
|
||||
let phoenix_heartbeat = tokio::spawn(async move {
|
||||
let mut timer = tokio::time::interval(Duration::from_secs(30));
|
||||
let mut timer = tokio::time::interval(HEARTBEAT);
|
||||
loop {
|
||||
timer.tick().await;
|
||||
let Ok(_) = sender.send("phoenix", EgressControlMessage::Heartbeat(Empty {})).await else { break };
|
||||
let Ok(_) = sender
|
||||
.send("phoenix", EgressControlMessage::Heartbeat(Empty {}))
|
||||
.await
|
||||
else {
|
||||
break;
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
@@ -174,30 +189,28 @@ where
|
||||
match message.into_text() {
|
||||
Ok(m_str) => match serde_json::from_str::<PhoenixMessage<I, R>>(&m_str) {
|
||||
Ok(m) => match m.payload {
|
||||
Payload::Message(m) => handler(Ok(m.into())).await,
|
||||
Payload::Message(payload) => handler(Ok(payload.into()), m.reference).await,
|
||||
Payload::Reply(status) => match status {
|
||||
ReplyMessage::PhxReply(phx_reply) => match phx_reply {
|
||||
// TODO: Here we should pass error info to a subscriber
|
||||
PhxReply::Error(info) => {
|
||||
tracing::warn!("Portal error: {info:?}");
|
||||
handler(Err(ErrorReply {
|
||||
error: info,
|
||||
reference: m.reference,
|
||||
}))
|
||||
.await
|
||||
handler(Err(ErrorReply { error: info }), m.reference).await
|
||||
}
|
||||
PhxReply::Ok(reply) => match reply {
|
||||
OkReply::NoMessage(Empty {}) => {
|
||||
tracing::trace!("Phoenix status message")
|
||||
tracing::trace!(target: "phoenix_status", "Phoenix status message")
|
||||
}
|
||||
OkReply::Message(payload) => {
|
||||
handler(Ok(payload.into()), m.reference).await
|
||||
}
|
||||
OkReply::Message(m) => handler(Ok(m.into())).await,
|
||||
},
|
||||
},
|
||||
ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"),
|
||||
},
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Error deserializing message {m_str}: {e:?}");
|
||||
tracing::error!(message = "Error deserializing message", message_string = m_str, error = ?e);
|
||||
}
|
||||
},
|
||||
_ => tracing::error!("Received message that is not text"),
|
||||
@@ -254,8 +267,6 @@ pub type MessageResult<M> = std::result::Result<M, ErrorReply>;
|
||||
pub struct ErrorReply {
|
||||
/// Information of the error
|
||||
pub error: ErrorInfo,
|
||||
/// Reference to the message that caused the error
|
||||
pub reference: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
||||
|
||||
@@ -21,6 +21,9 @@ pub enum ConnlibError {
|
||||
/// Request error for websocket connection.
|
||||
#[error("Error forming request: {0}")]
|
||||
RequestError(#[from] tokio_tungstenite::tungstenite::http::Error),
|
||||
/// Websocket heartbeat timedout
|
||||
#[error("Websocket heartbeat timedout")]
|
||||
WebsocketTimeout(#[from] tokio_stream::Elapsed),
|
||||
/// Error during websocket connection.
|
||||
#[error("Portal connection error: {0}")]
|
||||
PortalConnectionError(#[from] tokio_tungstenite::tungstenite::error::Error),
|
||||
@@ -99,6 +102,12 @@ pub enum ConnlibError {
|
||||
/// A panic occurred with a non-string payload.
|
||||
#[error("Panicked with a non-string payload")]
|
||||
PanicNonStringPayload,
|
||||
/// Received connection details that might be stale
|
||||
#[error("Unexpected connection details")]
|
||||
UnexpectedConnectionDetails,
|
||||
/// Invalid phoenix channel reference
|
||||
#[error("Invalid phoenix channel reply reference")]
|
||||
InvalidReference,
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
|
||||
@@ -16,7 +16,7 @@ use tokio::{runtime::Runtime, sync::mpsc::Receiver};
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic},
|
||||
control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic, Reference},
|
||||
messages::{Key, ResourceDescription},
|
||||
Error, Result,
|
||||
};
|
||||
@@ -33,7 +33,7 @@ pub trait ControlSession<T, CB: Callbacks> {
|
||||
/// Start control-plane with the given private-key in the background.
|
||||
async fn start(
|
||||
private_key: StaticSecret,
|
||||
receiver: Receiver<MessageResult<T>>,
|
||||
receiver: Receiver<(MessageResult<T>, Option<Reference>)>,
|
||||
control_signal: PhoenixSenderWithTopic,
|
||||
callbacks: CB,
|
||||
) -> Result<()>;
|
||||
@@ -292,11 +292,11 @@ where
|
||||
// to force queue ordering.
|
||||
let (control_plane_sender, control_plane_receiver) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
let mut connection = PhoenixChannel::<_, U, R, M>::new(connect_url, move |msg| {
|
||||
let mut connection = PhoenixChannel::<_, U, R, M>::new(connect_url, move |msg, reference| {
|
||||
let control_plane_sender = control_plane_sender.clone();
|
||||
async move {
|
||||
tracing::trace!("Received message: {msg:?}");
|
||||
if let Err(e) = control_plane_sender.send(msg).await {
|
||||
if let Err(e) = control_plane_sender.send((msg, reference)).await {
|
||||
tracing::warn!("Received a message after handler already closed: {e}. Probably message received during session clean up.");
|
||||
}
|
||||
}
|
||||
@@ -318,8 +318,8 @@ where
|
||||
tracing::debug!("Attempting connection to portal...");
|
||||
let result = connection.start(vec![topic.clone()], || exponential_backoff.reset()).await;
|
||||
tracing::warn!("Disconnected from the portal");
|
||||
if let Err(err) = &result {
|
||||
tracing::warn!("Portal connection error: {err}");
|
||||
if let Err(e) = &result {
|
||||
tracing::warn!(error = ?e, "Portal connection error");
|
||||
}
|
||||
if let Some(t) = exponential_backoff.next_backoff() {
|
||||
tracing::warn!("Error connecting to portal, retrying in {} seconds", t.as_secs());
|
||||
|
||||
@@ -4,7 +4,7 @@ use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
|
||||
use boringtun::x25519::StaticSecret;
|
||||
use firezone_tunnel::{ControlSignal, Tunnel};
|
||||
use libs_common::{
|
||||
control::{MessageResult, PhoenixSenderWithTopic},
|
||||
control::{MessageResult, PhoenixSenderWithTopic, Reference},
|
||||
messages::{Id, ResourceDescription},
|
||||
Callbacks, ControlSession, Result,
|
||||
};
|
||||
@@ -33,7 +33,8 @@ impl ControlSignal for ControlSignaler {
|
||||
async fn signal_connection_to(
|
||||
&self,
|
||||
resource: &ResourceDescription,
|
||||
_connected_gateway_ids: Vec<Id>,
|
||||
_connected_gateway_ids: &[Id],
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
tracing::warn!("A message to network resource: {resource:?} was discarded, gateways aren't meant to be used as clients.");
|
||||
Ok(())
|
||||
@@ -42,11 +43,14 @@ impl ControlSignal for ControlSignaler {
|
||||
|
||||
impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
async fn start(mut self, mut receiver: Receiver<MessageResult<IngressMessages>>) -> Result<()> {
|
||||
async fn start(
|
||||
mut self,
|
||||
mut receiver: Receiver<(MessageResult<IngressMessages>, Option<Reference>)>,
|
||||
) -> Result<()> {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(10));
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(msg) = receiver.recv() => {
|
||||
Some((msg, _)) = receiver.recv() => {
|
||||
match msg {
|
||||
Ok(msg) => self.handle_message(msg).await?,
|
||||
Err(_msg_reply) => todo!(),
|
||||
@@ -144,7 +148,7 @@ impl<CB: Callbacks + 'static> ControlSession<IngressMessages, CB> for ControlPla
|
||||
#[tracing::instrument(level = "trace", skip(private_key, callbacks))]
|
||||
async fn start(
|
||||
private_key: StaticSecret,
|
||||
receiver: Receiver<MessageResult<IngressMessages>>,
|
||||
receiver: Receiver<(MessageResult<IngressMessages>, Option<Reference>)>,
|
||||
control_signal: PhoenixSenderWithTopic,
|
||||
callbacks: CB,
|
||||
) -> Result<()> {
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::sync::Arc;
|
||||
use tracing::instrument;
|
||||
|
||||
use libs_common::{
|
||||
control::Reference,
|
||||
messages::{Id, Key, Relay, RequestConnection, ResourceDescription, ReuseConnection},
|
||||
Callbacks, Error, Result,
|
||||
};
|
||||
@@ -255,16 +256,39 @@ where
|
||||
resource_id: Id,
|
||||
gateway_id: Id,
|
||||
relays: Vec<Relay>,
|
||||
reference: Option<Reference>,
|
||||
) -> Result<Request> {
|
||||
self.resources_gateways
|
||||
.lock()
|
||||
.insert(resource_id, gateway_id);
|
||||
tracing::trace!("Received gateways and relays for resource, requesting connection");
|
||||
let resource_description = self
|
||||
.resources
|
||||
.read()
|
||||
.get_by_id(&resource_id)
|
||||
.ok_or(Error::UnknownResource)?
|
||||
.clone();
|
||||
|
||||
let reference: usize = reference
|
||||
.ok_or(Error::InvalidReference)?
|
||||
.parse()
|
||||
.map_err(|_| Error::InvalidReference)?;
|
||||
{
|
||||
let mut awaiting_connections = self.awaiting_connection.lock();
|
||||
let Some(awaiting_connection) = awaiting_connections.get_mut(&resource_id) else {
|
||||
return Err(Error::UnexpectedConnectionDetails);
|
||||
};
|
||||
awaiting_connection.response_recieved = true;
|
||||
if awaiting_connection.total_attemps != reference
|
||||
|| resource_description
|
||||
.ips()
|
||||
.iter()
|
||||
.any(|&ip| self.peers_by_ip.read().exact_match(ip).is_some())
|
||||
{
|
||||
return Err(Error::UnexpectedConnectionDetails);
|
||||
}
|
||||
}
|
||||
|
||||
self.resources_gateways
|
||||
.lock()
|
||||
.insert(resource_id, gateway_id);
|
||||
{
|
||||
let mut gateway_awaiting_connection = self.gateway_awaiting_connection.lock();
|
||||
if let Some(g) = gateway_awaiting_connection.get_mut(&gateway_id) {
|
||||
@@ -278,23 +302,38 @@ where
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut peers_by_ip = self.peers_by_ip.write();
|
||||
let peer = peers_by_ip
|
||||
.iter()
|
||||
.find_map(|(_, p)| (p.conn_id == gateway_id).then_some(p))
|
||||
.cloned();
|
||||
if let Some(peer) = peer {
|
||||
for ip in resource_description.ips() {
|
||||
peer.add_allowed_ip(ip);
|
||||
peers_by_ip.insert(ip, Arc::clone(&peer));
|
||||
let found = {
|
||||
let mut peers_by_ip = self.peers_by_ip.write();
|
||||
let peer = peers_by_ip
|
||||
.iter()
|
||||
.find_map(|(_, p)| (p.conn_id == gateway_id).then_some(p))
|
||||
.cloned();
|
||||
if let Some(peer) = peer {
|
||||
for ip in resource_description.ips() {
|
||||
peer.add_allowed_ip(ip);
|
||||
peers_by_ip.insert(ip, Arc::clone(&peer));
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if found {
|
||||
self.awaiting_connection.lock().remove(&resource_id);
|
||||
return Ok(Request::ReuseConnection(ReuseConnection {
|
||||
resource_id,
|
||||
gateway_id,
|
||||
}));
|
||||
}
|
||||
}
|
||||
let peer_connection = self.initialize_peer_request(relays).await?;
|
||||
let peer_connection = {
|
||||
let peer_connection = Arc::new(self.initialize_peer_request(relays).await?);
|
||||
let mut peer_connections = self.peer_connections.lock();
|
||||
peer_connections.insert(gateway_id, Arc::clone(&peer_connection));
|
||||
peer_connection
|
||||
};
|
||||
|
||||
self.set_connection_state_update_initiator(&peer_connection, gateway_id, resource_id);
|
||||
|
||||
let data_channel = peer_connection.create_data_channel("data", None).await?;
|
||||
@@ -360,10 +399,6 @@ where
|
||||
.await
|
||||
.expect("Developer error: set_local_description was just called above");
|
||||
|
||||
self.peer_connections
|
||||
.lock()
|
||||
.insert(gateway_id, peer_connection);
|
||||
|
||||
Ok(Request::NewConnection(RequestConnection {
|
||||
resource_id,
|
||||
gateway_id,
|
||||
|
||||
@@ -29,12 +29,7 @@ use webrtc::{
|
||||
peer_connection::RTCPeerConnection,
|
||||
};
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
net::IpAddr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use std::{collections::HashMap, net::IpAddr, sync::Arc, time::Duration};
|
||||
|
||||
use libs_common::{
|
||||
messages::{Id, Interface as InterfaceConfig, ResourceDescription},
|
||||
@@ -93,6 +88,8 @@ const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1);
|
||||
const HANDSHAKE_RATE_LIMIT: u64 = 100;
|
||||
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
|
||||
|
||||
const MAX_SIGNAL_CONNECTION_DELAY: Duration = Duration::from_secs(2);
|
||||
|
||||
/// Represent's the tunnel actual peer's config
|
||||
/// Obtained from libs_common's Peer
|
||||
#[derive(Clone)]
|
||||
@@ -125,10 +122,17 @@ pub trait ControlSignal {
|
||||
async fn signal_connection_to(
|
||||
&self,
|
||||
resource: &ResourceDescription,
|
||||
connected_gateway_ids: Vec<Id>,
|
||||
connected_gateway_ids: &[Id],
|
||||
reference: usize,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
struct AwaitingConnectionDetails {
|
||||
pub total_attemps: usize,
|
||||
pub response_recieved: bool,
|
||||
}
|
||||
|
||||
// TODO: We should use newtypes for each kind of Id
|
||||
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets
|
||||
/// to communicate between peers.
|
||||
@@ -143,7 +147,7 @@ pub struct Tunnel<C: ControlSignal, CB: Callbacks> {
|
||||
public_key: PublicKey,
|
||||
peers_by_ip: RwLock<IpNetworkTable<Arc<Peer>>>,
|
||||
peer_connections: Mutex<HashMap<Id, Arc<RTCPeerConnection>>>,
|
||||
awaiting_connection: Mutex<HashSet<Id>>,
|
||||
awaiting_connection: Mutex<HashMap<Id, AwaitingConnectionDetails>>,
|
||||
gateway_awaiting_connection: Mutex<HashMap<Id, Vec<IpNetwork>>>,
|
||||
resources_gateways: Mutex<HashMap<Id, Id>>,
|
||||
webrtc_api: API,
|
||||
@@ -160,12 +164,13 @@ pub struct TunnelStats {
|
||||
public_key: String,
|
||||
peers_by_ip: HashMap<IpNetwork, PeerStats>,
|
||||
peer_connections: Vec<Id>,
|
||||
awaiting_connection: HashSet<Id>,
|
||||
gateway_awaiting_connection: HashMap<Id, Vec<IpNetwork>>,
|
||||
resource_gateways: HashMap<Id, Id>,
|
||||
dns_resources: HashMap<String, ResourceDescription>,
|
||||
network_resources: HashMap<IpNetwork, ResourceDescription>,
|
||||
gateway_public_keys: HashMap<Id, String>,
|
||||
|
||||
awaiting_connection: HashMap<Id, AwaitingConnectionDetails>,
|
||||
gateway_awaiting_connection: HashMap<Id, Vec<IpNetwork>>,
|
||||
}
|
||||
|
||||
impl<C, CB> Tunnel<C, CB>
|
||||
@@ -654,13 +659,13 @@ where
|
||||
// and we are finding another packet to the same address (otherwise we would just use peer_connections here)
|
||||
let mut awaiting_connection = dev.awaiting_connection.lock();
|
||||
let id = resource.id();
|
||||
if !awaiting_connection.contains(&id) {
|
||||
if awaiting_connection.get(&id).is_none() {
|
||||
tracing::trace!(
|
||||
message = "Found new intent to send packets to resource",
|
||||
resource_ip = %dst_addr
|
||||
);
|
||||
|
||||
awaiting_connection.insert(id);
|
||||
awaiting_connection.insert(id, Default::default());
|
||||
let dev = Arc::clone(&dev);
|
||||
|
||||
let mut connected_gateway_ids: Vec<_> = dev
|
||||
@@ -676,15 +681,38 @@ where
|
||||
message = "Currently connected gateways", gateways = ?connected_gateway_ids
|
||||
);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = dev
|
||||
.control_signaler
|
||||
.signal_connection_to(&resource, connected_gateway_ids)
|
||||
.await
|
||||
{
|
||||
// Not a deadlock because this is a different task
|
||||
dev.awaiting_connection.lock().remove(&id);
|
||||
tracing::error!(message = "couldn't start protocol for new connection to resource", error = ?e);
|
||||
let _ = dev.callbacks.on_error(&e);
|
||||
let mut interval =
|
||||
tokio::time::interval(MAX_SIGNAL_CONNECTION_DELAY);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let reference = {
|
||||
let mut awaiting_connections =
|
||||
dev.awaiting_connection.lock();
|
||||
let Some(awaiting_connection) =
|
||||
awaiting_connections.get_mut(&resource.id())
|
||||
else {
|
||||
break;
|
||||
};
|
||||
if awaiting_connection.response_recieved {
|
||||
break;
|
||||
}
|
||||
awaiting_connection.total_attemps += 1;
|
||||
awaiting_connection.total_attemps
|
||||
};
|
||||
if let Err(e) = dev
|
||||
.control_signaler
|
||||
.signal_connection_to(
|
||||
&resource,
|
||||
&connected_gateway_ids,
|
||||
reference,
|
||||
)
|
||||
.await
|
||||
{
|
||||
// Not a deadlock because this is a different task
|
||||
dev.awaiting_connection.lock().remove(&id);
|
||||
tracing::error!(message = "couldn't start protocol for new connection to resource", error = ?e);
|
||||
let _ = dev.callbacks.on_error(&e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user