From 2277d92c88a41761e464c91bca8960c8cb6f1387 Mon Sep 17 00:00:00 2001 From: Gabi Date: Thu, 18 Jan 2024 15:08:43 -0300 Subject: [PATCH] fix(connlib): handle expiration messages correctly (#3292) While working on #3288 I saw a few messages that we don't explicitly handle from the portal. This PR changes it so that we handle them correctly and we don't just depend on coincidental behavior.. --- rust/connlib/clients/shared/src/control.rs | 24 +++++-- rust/connlib/shared/src/control.rs | 78 +++++++++++++++++----- rust/connlib/shared/src/error.rs | 2 + 3 files changed, 83 insertions(+), 21 deletions(-) diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index 09eb16ef9..b631fedc7 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -1,4 +1,5 @@ use async_compression::tokio::bufread::GzipEncoder; +use connlib_shared::control::ChannelError; use connlib_shared::control::KnownError; use connlib_shared::control::Reason; use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer}; @@ -12,7 +13,7 @@ use crate::messages::{ GatewayIceCandidates, InitClient, Messages, }; use connlib_shared::{ - control::{ErrorInfo, ErrorReply, PhoenixSenderWithTopic, Reference}, + control::{ErrorInfo, PhoenixSenderWithTopic, Reference}, messages::{GatewayId, ResourceDescription, ResourceId}, Callbacks, Error::{self}, @@ -271,12 +272,12 @@ impl ControlPlane { #[tracing::instrument(level = "trace", skip(self))] pub async fn handle_error( &mut self, - reply_error: ErrorReply, + reply_error: ChannelError, reference: Option, topic: String, ) -> Result<()> { - match (reply_error.error, reference) { - (ErrorInfo::Offline, Some(reference)) => { + match (reply_error, reference) { + (ChannelError::ErrorReply(ErrorInfo::Offline), Some(reference)) => { let Ok(resource_id) = reference.parse::() else { tracing::warn!("The portal responded with an Offline error. Is the Resource associated with any online Gateways or Relays?"); return Ok(()); @@ -284,12 +285,23 @@ impl ControlPlane { // TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection self.tunnel.cleanup_connection(resource_id); } - (ErrorInfo::Reason(Reason::Known(KnownError::UnmatchedTopic)), _) => { + ( + ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known( + KnownError::UnmatchedTopic, + ))), + _, + ) => { if let Err(e) = self.phoenix_channel.get_sender().join_topic(topic).await { tracing::debug!(err = ?e, "couldn't join topic: {e:#?}"); } } - (ErrorInfo::Reason(Reason::Known(KnownError::TokenExpired)), _) => { + ( + ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known( + KnownError::TokenExpired, + ))), + _, + ) + | (ChannelError::ErrorMsg(Error::TokenExpired), _) => { return Err(Error::TokenExpired); } _ => {} diff --git a/rust/connlib/shared/src/control.rs b/rust/connlib/shared/src/control.rs index 89b2b22a2..e60ca5302 100644 --- a/rust/connlib/shared/src/control.rs +++ b/rust/connlib/shared/src/control.rs @@ -149,10 +149,7 @@ where let process_messages = tokio_stream::StreamExt::map(read.timeout(HEARTBEAT_TIMEOUT), |m| { m.map_err(Error::from)?.map_err(Error::from) }) - .try_for_each(|message| async { - Self::message_process(handler, message).await; - Ok(()) - }); + .try_for_each(|message| async { Self::message_process(handler, message).await }); // Would we like to do write.send_all(futures::stream(Message::text(...))) ? // yes. @@ -214,7 +211,7 @@ where } #[tracing::instrument(level = "trace", skip(handler))] - async fn message_process(handler: &F, message: tungstenite::Message) { + async fn message_process(handler: &F, message: tungstenite::Message) -> Result<()> { tracing::trace!("{message:?}"); match message.into_text() { @@ -228,7 +225,8 @@ where // TODO: Here we should pass error info to a subscriber PhxReply::Error(info) => { tracing::debug!("Portal error: {info:?}"); - handler(Err(ErrorReply { error: info }), m.reference, m.topic).await + handler(Err(ChannelError::ErrorReply(info)), m.reference, m.topic) + .await } PhxReply::Ok(reply) => match reply { OkReply::NoMessage(Empty {}) => { @@ -241,6 +239,17 @@ where }, ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"), }, + Payload::ControlMessage(ControlMessage::PhxClose(_)) => { + return Err(Error::ClosedByPortal) + } + Payload::ControlMessage(ControlMessage::TokenExpired(_)) => { + handler( + Err(ChannelError::ErrorMsg(Error::TokenExpired)), + m.reference, + m.topic, + ) + .await + } }, Err(e) => { tracing::error!(message = "Error deserializing message", message_string = m_str, error = ?e); @@ -248,6 +257,8 @@ where }, _ => tracing::error!("Received message that is not text"), } + + Ok(()) } /// Obtains a new sender that can be used to send message with this [PhoenixChannel] to the portal. @@ -297,14 +308,12 @@ where /// A result type that is used to communicate to the client/gateway /// control loop the message received. -pub type MessageResult = std::result::Result; +pub type MessageResult = std::result::Result; -/// This struct holds info about an error reply which will be passed -/// to connlib's control plane. -#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] -pub struct ErrorReply { - /// Information of the error - pub error: ErrorInfo, +#[derive(Debug)] +pub enum ChannelError { + ErrorReply(ErrorInfo), + ErrorMsg(Error), } #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] @@ -314,9 +323,17 @@ enum Payload { // but that makes everything even more convoluted! // and we need to think how to make this whole mess less convoluted. Reply(ReplyMessage), + ControlMessage(ControlMessage), Message(T), } +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case", tag = "event", content = "payload")] +enum ControlMessage { + PhxClose(Empty), + TokenExpired(Empty), +} + #[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] pub struct PhoenixMessage { // TODO: we should use a newtype pattern for topics @@ -526,8 +543,8 @@ impl PhoenixSender { #[cfg(test)] mod tests { use crate::control::{ - ErrorInfo, KnownError, Payload, PhxReply::Error, Reason, ReplyMessage::PhxReply, - UnknownError, + ControlMessage, Empty, ErrorInfo, KnownError, Payload, PhxReply::Error, Reason, + ReplyMessage::PhxReply, UnknownError, }; #[test] @@ -552,6 +569,37 @@ mod tests { assert_eq!(actual_reply, expected_reply); } + #[test] + fn phx_close() { + let actual_reply = r#" + { + "event": "phx_close", + "ref": null, + "topic": "client", + "payload": {} + } + "#; + let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); + let expected_reply = Payload::<(), ()>::ControlMessage(ControlMessage::PhxClose(Empty {})); + assert_eq!(actual_reply, expected_reply); + } + + #[test] + fn token_expired() { + let actual_reply = r#" + { + "event": "token_expired", + "ref": null, + "topic": "client", + "payload": {} + } + "#; + let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); + let expected_reply = + Payload::<(), ()>::ControlMessage(ControlMessage::TokenExpired(Empty {})); + assert_eq!(actual_reply, expected_reply); + } + #[test] fn unexpected_error_reply() { let actual_reply = r#" diff --git a/rust/connlib/shared/src/error.rs b/rust/connlib/shared/src/error.rs index 761edfde6..5a2aeffcd 100644 --- a/rust/connlib/shared/src/error.rs +++ b/rust/connlib/shared/src/error.rs @@ -153,6 +153,8 @@ pub enum ConnlibError { TokenExpired, #[error("Too many concurrent gateway connection requests")] TooManyConnectionRequests, + #[error("Channel connection closed by portal")] + ClosedByPortal, } impl ConnlibError {