fix(connlib): handle expiration messages correctly (#3292)

While working on #3288 I saw a few messages that we don't explicitly
handle from the portal.

This PR changes it so that we handle them correctly and we don't just
depend on coincidental behavior..
This commit is contained in:
Gabi
2024-01-18 15:08:43 -03:00
committed by GitHub
parent 32450c89d3
commit 2277d92c88
3 changed files with 83 additions and 21 deletions

View File

@@ -1,4 +1,5 @@
use async_compression::tokio::bufread::GzipEncoder;
use connlib_shared::control::ChannelError;
use connlib_shared::control::KnownError;
use connlib_shared::control::Reason;
use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer};
@@ -12,7 +13,7 @@ use crate::messages::{
GatewayIceCandidates, InitClient, Messages,
};
use connlib_shared::{
control::{ErrorInfo, ErrorReply, PhoenixSenderWithTopic, Reference},
control::{ErrorInfo, PhoenixSenderWithTopic, Reference},
messages::{GatewayId, ResourceDescription, ResourceId},
Callbacks,
Error::{self},
@@ -271,12 +272,12 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(self))]
pub async fn handle_error(
&mut self,
reply_error: ErrorReply,
reply_error: ChannelError,
reference: Option<Reference>,
topic: String,
) -> Result<()> {
match (reply_error.error, reference) {
(ErrorInfo::Offline, Some(reference)) => {
match (reply_error, reference) {
(ChannelError::ErrorReply(ErrorInfo::Offline), Some(reference)) => {
let Ok(resource_id) = reference.parse::<ResourceId>() else {
tracing::warn!("The portal responded with an Offline error. Is the Resource associated with any online Gateways or Relays?");
return Ok(());
@@ -284,12 +285,23 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
// TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection
self.tunnel.cleanup_connection(resource_id);
}
(ErrorInfo::Reason(Reason::Known(KnownError::UnmatchedTopic)), _) => {
(
ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known(
KnownError::UnmatchedTopic,
))),
_,
) => {
if let Err(e) = self.phoenix_channel.get_sender().join_topic(topic).await {
tracing::debug!(err = ?e, "couldn't join topic: {e:#?}");
}
}
(ErrorInfo::Reason(Reason::Known(KnownError::TokenExpired)), _) => {
(
ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known(
KnownError::TokenExpired,
))),
_,
)
| (ChannelError::ErrorMsg(Error::TokenExpired), _) => {
return Err(Error::TokenExpired);
}
_ => {}

View File

@@ -149,10 +149,7 @@ where
let process_messages = tokio_stream::StreamExt::map(read.timeout(HEARTBEAT_TIMEOUT), |m| {
m.map_err(Error::from)?.map_err(Error::from)
})
.try_for_each(|message| async {
Self::message_process(handler, message).await;
Ok(())
});
.try_for_each(|message| async { Self::message_process(handler, message).await });
// Would we like to do write.send_all(futures::stream(Message::text(...))) ?
// yes.
@@ -214,7 +211,7 @@ where
}
#[tracing::instrument(level = "trace", skip(handler))]
async fn message_process(handler: &F, message: tungstenite::Message) {
async fn message_process(handler: &F, message: tungstenite::Message) -> Result<()> {
tracing::trace!("{message:?}");
match message.into_text() {
@@ -228,7 +225,8 @@ where
// TODO: Here we should pass error info to a subscriber
PhxReply::Error(info) => {
tracing::debug!("Portal error: {info:?}");
handler(Err(ErrorReply { error: info }), m.reference, m.topic).await
handler(Err(ChannelError::ErrorReply(info)), m.reference, m.topic)
.await
}
PhxReply::Ok(reply) => match reply {
OkReply::NoMessage(Empty {}) => {
@@ -241,6 +239,17 @@ where
},
ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"),
},
Payload::ControlMessage(ControlMessage::PhxClose(_)) => {
return Err(Error::ClosedByPortal)
}
Payload::ControlMessage(ControlMessage::TokenExpired(_)) => {
handler(
Err(ChannelError::ErrorMsg(Error::TokenExpired)),
m.reference,
m.topic,
)
.await
}
},
Err(e) => {
tracing::error!(message = "Error deserializing message", message_string = m_str, error = ?e);
@@ -248,6 +257,8 @@ where
},
_ => tracing::error!("Received message that is not text"),
}
Ok(())
}
/// Obtains a new sender that can be used to send message with this [PhoenixChannel] to the portal.
@@ -297,14 +308,12 @@ where
/// A result type that is used to communicate to the client/gateway
/// control loop the message received.
pub type MessageResult<M> = std::result::Result<M, ErrorReply>;
pub type MessageResult<M> = std::result::Result<M, ChannelError>;
/// This struct holds info about an error reply which will be passed
/// to connlib's control plane.
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
pub struct ErrorReply {
/// Information of the error
pub error: ErrorInfo,
#[derive(Debug)]
pub enum ChannelError {
ErrorReply(ErrorInfo),
ErrorMsg(Error),
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
@@ -314,9 +323,17 @@ enum Payload<T, R> {
// but that makes everything even more convoluted!
// and we need to think how to make this whole mess less convoluted.
Reply(ReplyMessage<R>),
ControlMessage(ControlMessage),
Message(T),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum ControlMessage {
PhxClose(Empty),
TokenExpired(Empty),
}
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
pub struct PhoenixMessage<T, R> {
// TODO: we should use a newtype pattern for topics
@@ -526,8 +543,8 @@ impl PhoenixSender {
#[cfg(test)]
mod tests {
use crate::control::{
ErrorInfo, KnownError, Payload, PhxReply::Error, Reason, ReplyMessage::PhxReply,
UnknownError,
ControlMessage, Empty, ErrorInfo, KnownError, Payload, PhxReply::Error, Reason,
ReplyMessage::PhxReply, UnknownError,
};
#[test]
@@ -552,6 +569,37 @@ mod tests {
assert_eq!(actual_reply, expected_reply);
}
#[test]
fn phx_close() {
let actual_reply = r#"
{
"event": "phx_close",
"ref": null,
"topic": "client",
"payload": {}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::ControlMessage(ControlMessage::PhxClose(Empty {}));
assert_eq!(actual_reply, expected_reply);
}
#[test]
fn token_expired() {
let actual_reply = r#"
{
"event": "token_expired",
"ref": null,
"topic": "client",
"payload": {}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply =
Payload::<(), ()>::ControlMessage(ControlMessage::TokenExpired(Empty {}));
assert_eq!(actual_reply, expected_reply);
}
#[test]
fn unexpected_error_reply() {
let actual_reply = r#"

View File

@@ -153,6 +153,8 @@ pub enum ConnlibError {
TokenExpired,
#[error("Too many concurrent gateway connection requests")]
TooManyConnectionRequests,
#[error("Channel connection closed by portal")]
ClosedByPortal,
}
impl ConnlibError {