From c817473aefa7f0fba44cabc4bd6baa4748af2721 Mon Sep 17 00:00:00 2001 From: Gabi Date: Thu, 6 Jul 2023 15:47:01 -0300 Subject: [PATCH] Feat/connlib handle error messages (#1735) With this PR we handle in the client an error message due to gateway/relay although rate limiting is needed. Waiting for #1729 to be merged. --- docker-compose.yml | 2 - rust/connlib/libs/client/src/control.rs | 60 ++++++++++++++++++++---- rust/connlib/libs/client/src/messages.rs | 2 +- rust/connlib/libs/common/src/control.rs | 50 +++++++++++++++----- rust/connlib/libs/common/src/session.rs | 4 +- rust/connlib/libs/gateway/src/control.rs | 13 +++-- rust/connlib/libs/tunnel/src/lib.rs | 4 +- 7 files changed, 102 insertions(+), 33 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 55541669b..1374877b9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -135,8 +135,6 @@ services: depends_on: api: condition: 'service_healthy' - relay: - condition: 'service_healthy' networks: app: ipv4_address: 172.28.0.100 diff --git a/rust/connlib/libs/client/src/control.rs b/rust/connlib/libs/client/src/control.rs index 3e19bb610..5462f3ab1 100644 --- a/rust/connlib/libs/client/src/control.rs +++ b/rust/connlib/libs/client/src/control.rs @@ -3,10 +3,10 @@ use std::{sync::Arc, time::Duration}; use crate::messages::{Connect, EgressMessages, InitClient, Messages, Relays}; use boringtun::x25519::StaticSecret; use libs_common::{ - control::PhoenixSenderWithTopic, - error_type::ErrorType::{Fatal, Recoverable}, + control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic}, + error_type::ErrorType::{self, Fatal, Recoverable}, messages::{Id, ResourceDescription}, - Callbacks, ControlSession, Result, + Callbacks, ControlSession, Error, Result, }; use async_trait::async_trait; @@ -19,9 +19,14 @@ impl ControlSignal for ControlSignaler { self.control_signal // It's easier if self is not mut .clone() - .send(EgressMessages::ListRelays { - resource_id: resource.id(), - }) + .send_with_ref( + EgressMessages::ListRelays { + resource_id: resource.id(), + }, + // The resource id functions as the connection id since we can only have one connection + // outgoing for each resource. + resource.id(), + ) .await?; Ok(()) } @@ -40,11 +45,16 @@ struct ControlSignaler { impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] - async fn start(mut self, mut receiver: Receiver) { + async fn start(mut self, mut receiver: Receiver>) { let mut interval = tokio::time::interval(Duration::from_secs(10)); loop { tokio::select! { - Some(msg) = receiver.recv() => self.handle_message(msg).await, + Some(msg) = receiver.recv() => { + match msg { + Ok(msg) => self.handle_message(msg).await, + Err(msg_reply) => self.handle_error(msg_reply).await, + } + }, _ = interval.tick() => self.stats_event().await, else => break } @@ -126,7 +136,10 @@ impl ControlPlane { if let Err(err) = control_signaler .control_signal // TODO: create a reference number and keep track for the response - .send_with_ref(EgressMessages::RequestConnection(connection_request), 0) + .send_with_ref( + EgressMessages::RequestConnection(connection_request), + resource_id, + ) .await { tunnel.cleanup_connection(resource_id); @@ -153,6 +166,33 @@ impl ControlPlane { } } + #[tracing::instrument(level = "trace", skip(self))] + pub(super) async fn handle_error(&mut self, reply_error: ErrorReply) { + if matches!(reply_error.error, ErrorInfo::Offline) { + match reply_error.reference { + Some(reference) => { + let Ok(id) = reference.parse() else { + tracing::error!( + "An offline error came back with a reference to a non-valid resource id" + ); + self.tunnel.callbacks().on_error(&Error::ControlProtocolError, ErrorType::Recoverable); + return; + }; + // TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection + self.tunnel.cleanup_connection(id); + } + None => { + tracing::error!( + "An offline portal error came without a reference that originated the error" + ); + self.tunnel + .callbacks() + .on_error(&Error::ControlProtocolError, ErrorType::Recoverable); + } + } + } + } + #[tracing::instrument(level = "trace", skip(self))] pub(super) async fn stats_event(&mut self) { // TODO @@ -164,7 +204,7 @@ impl ControlSession for ControlPlane #[tracing::instrument(level = "trace", skip(private_key, callbacks))] async fn start( private_key: StaticSecret, - receiver: Receiver, + receiver: Receiver>, control_signal: PhoenixSenderWithTopic, callbacks: CB, ) -> Result<()> { diff --git a/rust/connlib/libs/client/src/messages.rs b/rust/connlib/libs/client/src/messages.rs index 29f454e8b..cb6f64233 100644 --- a/rust/connlib/libs/client/src/messages.rs +++ b/rust/connlib/libs/client/src/messages.rs @@ -128,7 +128,7 @@ mod test { #[test] fn connection_ready_deserialization() { let message = r#"{ - "ref": 0, + "ref": "0", "topic": "device", "event": "phx_reply", "payload": { diff --git a/rust/connlib/libs/common/src/control.rs b/rust/connlib/libs/common/src/control.rs index bdb68fd0a..d1d45d316 100644 --- a/rust/connlib/libs/common/src/control.rs +++ b/rust/connlib/libs/common/src/control.rs @@ -68,7 +68,6 @@ fn make_request(uri: &Url) -> Result { .header("Upgrade", "websocket") .header("Sec-WebSocket-Version", "13") .header("Sec-WebSocket-Key", key) - // TODO: Get OS Info here (os_info crate) .header("User-Agent", get_user_agent()) .uri(uri.as_str()) .body(())?; @@ -80,7 +79,7 @@ where I: DeserializeOwned, R: DeserializeOwned, M: From + From, - F: Fn(M) -> Fut, + F: Fn(MessageResult) -> Fut, Fut: Future + Send + 'static, { /// Starts the tunnel with the parameters given in [Self::new]. @@ -169,16 +168,23 @@ where match message.into_text() { Ok(m_str) => match serde_json::from_str::>(&m_str) { Ok(m) => match m.payload { - Payload::Message(m) => handler(m.into()).await, + Payload::Message(m) => handler(Ok(m.into())).await, Payload::Reply(status) => match status { ReplyMessage::PhxReply(phx_reply) => match phx_reply { // TODO: Here we should pass error info to a subscriber - PhxReply::Error(info) => tracing::error!("Portal error: {info:?}"), + PhxReply::Error(info) => { + tracing::warn!("Portal error: {info:?}"); + handler(Err(ErrorReply { + error: info, + reference: m.reference, + })) + .await + } PhxReply::Ok(reply) => match reply { OkReply::NoMessage(Empty {}) => { tracing::trace!("Phoenix status message") } - OkReply::Message(m) => handler(m.into()).await, + OkReply::Message(m) => handler(Ok(m.into())).await, }, }, ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"), @@ -232,6 +238,20 @@ where } } +/// A result type that is used to communicate to the client/gateway +/// control loop the message received. +pub type MessageResult = std::result::Result; + +/// This struct holds info about an error reply which will be passed +/// to connlib's control plane. +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +pub struct ErrorReply { + /// Information of the error + pub error: ErrorInfo, + /// Reference to the message that caused the error + pub reference: Option, +} + #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] #[serde(untagged)] enum Payload { @@ -248,11 +268,11 @@ pub struct PhoenixMessage { #[serde(flatten)] payload: Payload, #[serde(rename = "ref")] - reference: Option, + reference: Option, } impl PhoenixMessage { - pub fn new(topic: impl Into, payload: T, reference: Option) -> Self { + pub fn new(topic: impl Into, payload: T, reference: Option) -> Self { Self { topic: topic.into(), payload: Payload::Message(payload), @@ -298,9 +318,10 @@ enum OkReply { NoMessage(Empty), } +/// This represents the info we have about the error #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[serde(rename_all = "snake_case")] -enum ErrorInfo { +pub enum ErrorInfo { Reason(String), Offline, } @@ -342,7 +363,11 @@ impl PhoenixSenderWithTopic { /// Sends a message to the associated topic using a [PhoenixSender] also setting the ref /// /// See [PhoenixSender::send] - pub async fn send_with_ref(&mut self, payload: impl Serialize, reference: i32) -> Result<()> { + pub async fn send_with_ref( + &mut self, + payload: impl Serialize, + reference: impl ToString, + ) -> Result<()> { self.phoenix_sender .send_with_ref(&self.topic, payload, reference) .await @@ -354,7 +379,7 @@ impl PhoenixSender { &mut self, topic: impl Into, payload: impl Serialize, - reference: Option, + reference: Option, ) -> Result<()> { // We don't care about the reply type when serializing let str = serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, reference))?; @@ -381,9 +406,10 @@ impl PhoenixSender { &mut self, topic: impl Into, payload: impl Serialize, - reference: i32, + reference: impl ToString, ) -> Result<()> { - self.send_internal(topic, payload, Some(reference)).await + self.send_internal(topic, payload, Some(reference.to_string())) + .await } /// Join a phoenix topic, meaning that after this method is invoked [PhoenixChannel] will diff --git a/rust/connlib/libs/common/src/session.rs b/rust/connlib/libs/common/src/session.rs index 5d4af8c5e..d352c5f0c 100644 --- a/rust/connlib/libs/common/src/session.rs +++ b/rust/connlib/libs/common/src/session.rs @@ -13,7 +13,7 @@ use url::Url; use uuid::Uuid; use crate::{ - control::{PhoenixChannel, PhoenixSenderWithTopic}, + control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic}, error_type::ErrorType, messages::{Key, ResourceDescription, ResourceDescriptionCidr}, Error, Result, @@ -26,7 +26,7 @@ pub trait ControlSession { /// Start control-plane with the given private-key in the background. async fn start( private_key: StaticSecret, - receiver: Receiver, + receiver: Receiver>, control_signal: PhoenixSenderWithTopic, callbacks: CB, ) -> Result<()>; diff --git a/rust/connlib/libs/gateway/src/control.rs b/rust/connlib/libs/gateway/src/control.rs index 12099d077..ecd8bc6e3 100644 --- a/rust/connlib/libs/gateway/src/control.rs +++ b/rust/connlib/libs/gateway/src/control.rs @@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration}; use boringtun::x25519::StaticSecret; use firezone_tunnel::{ControlSignal, Tunnel}; use libs_common::{ - control::PhoenixSenderWithTopic, + control::{MessageResult, PhoenixSenderWithTopic}, error_type::ErrorType::{Fatal, Recoverable}, messages::ResourceDescription, Callbacks, ControlSession, Result, @@ -36,11 +36,16 @@ impl ControlSignal for ControlSignaler { impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] - async fn start(mut self, mut receiver: Receiver) { + async fn start(mut self, mut receiver: Receiver>) { let mut interval = tokio::time::interval(Duration::from_secs(10)); loop { tokio::select! { - Some(msg) = receiver.recv() => self.handle_message(msg).await, + Some(msg) = receiver.recv() => { + match msg { + Ok(msg) => self.handle_message(msg).await, + Err(_msg_reply) => todo!(), + } + }, _ = interval.tick() => self.stats_event().await, else => break } @@ -123,7 +128,7 @@ impl ControlSession for ControlPla #[tracing::instrument(level = "trace", skip(private_key, callbacks))] async fn start( private_key: StaticSecret, - receiver: Receiver, + receiver: Receiver>, control_signal: PhoenixSenderWithTopic, callbacks: CB, ) -> Result<()> { diff --git a/rust/connlib/libs/tunnel/src/lib.rs b/rust/connlib/libs/tunnel/src/lib.rs index 45a153ada..51618e0b5 100644 --- a/rust/connlib/libs/tunnel/src/lib.rs +++ b/rust/connlib/libs/tunnel/src/lib.rs @@ -88,7 +88,7 @@ mod device_channel; mod device_channel; const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1); -const REFRESH_PEERS_TIEMRS_INTERVAL: Duration = Duration::from_secs(1); +const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1); // Note: Taken from boringtun const HANDSHAKE_RATE_LIMIT: u64 = 100; @@ -282,7 +282,7 @@ where let tunnel = self.clone(); tokio::spawn(async move { - let mut interval = tokio::time::interval(REFRESH_PEERS_TIEMRS_INTERVAL); + let mut interval = tokio::time::interval(REFRESH_PEERS_TIMERS_INTERVAL); interval.set_missed_tick_behavior(MissedTickBehavior::Delay); let mut dst_buf = [0u8; MAX_UDP_SIZE];