mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
fix(connlib): handle expiration messages correctly (#3292)
While working on #3288 I saw a few messages that we don't explicitly handle from the portal. This PR changes it so that we handle them correctly and we don't just depend on coincidental behavior..
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
use async_compression::tokio::bufread::GzipEncoder;
|
||||
use connlib_shared::control::ChannelError;
|
||||
use connlib_shared::control::KnownError;
|
||||
use connlib_shared::control::Reason;
|
||||
use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer};
|
||||
@@ -12,7 +13,7 @@ use crate::messages::{
|
||||
GatewayIceCandidates, InitClient, Messages,
|
||||
};
|
||||
use connlib_shared::{
|
||||
control::{ErrorInfo, ErrorReply, PhoenixSenderWithTopic, Reference},
|
||||
control::{ErrorInfo, PhoenixSenderWithTopic, Reference},
|
||||
messages::{GatewayId, ResourceDescription, ResourceId},
|
||||
Callbacks,
|
||||
Error::{self},
|
||||
@@ -271,12 +272,12 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn handle_error(
|
||||
&mut self,
|
||||
reply_error: ErrorReply,
|
||||
reply_error: ChannelError,
|
||||
reference: Option<Reference>,
|
||||
topic: String,
|
||||
) -> Result<()> {
|
||||
match (reply_error.error, reference) {
|
||||
(ErrorInfo::Offline, Some(reference)) => {
|
||||
match (reply_error, reference) {
|
||||
(ChannelError::ErrorReply(ErrorInfo::Offline), Some(reference)) => {
|
||||
let Ok(resource_id) = reference.parse::<ResourceId>() else {
|
||||
tracing::warn!("The portal responded with an Offline error. Is the Resource associated with any online Gateways or Relays?");
|
||||
return Ok(());
|
||||
@@ -284,12 +285,23 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
|
||||
// TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection
|
||||
self.tunnel.cleanup_connection(resource_id);
|
||||
}
|
||||
(ErrorInfo::Reason(Reason::Known(KnownError::UnmatchedTopic)), _) => {
|
||||
(
|
||||
ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known(
|
||||
KnownError::UnmatchedTopic,
|
||||
))),
|
||||
_,
|
||||
) => {
|
||||
if let Err(e) = self.phoenix_channel.get_sender().join_topic(topic).await {
|
||||
tracing::debug!(err = ?e, "couldn't join topic: {e:#?}");
|
||||
}
|
||||
}
|
||||
(ErrorInfo::Reason(Reason::Known(KnownError::TokenExpired)), _) => {
|
||||
(
|
||||
ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known(
|
||||
KnownError::TokenExpired,
|
||||
))),
|
||||
_,
|
||||
)
|
||||
| (ChannelError::ErrorMsg(Error::TokenExpired), _) => {
|
||||
return Err(Error::TokenExpired);
|
||||
}
|
||||
_ => {}
|
||||
|
||||
@@ -149,10 +149,7 @@ where
|
||||
let process_messages = tokio_stream::StreamExt::map(read.timeout(HEARTBEAT_TIMEOUT), |m| {
|
||||
m.map_err(Error::from)?.map_err(Error::from)
|
||||
})
|
||||
.try_for_each(|message| async {
|
||||
Self::message_process(handler, message).await;
|
||||
Ok(())
|
||||
});
|
||||
.try_for_each(|message| async { Self::message_process(handler, message).await });
|
||||
|
||||
// Would we like to do write.send_all(futures::stream(Message::text(...))) ?
|
||||
// yes.
|
||||
@@ -214,7 +211,7 @@ where
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(handler))]
|
||||
async fn message_process(handler: &F, message: tungstenite::Message) {
|
||||
async fn message_process(handler: &F, message: tungstenite::Message) -> Result<()> {
|
||||
tracing::trace!("{message:?}");
|
||||
|
||||
match message.into_text() {
|
||||
@@ -228,7 +225,8 @@ where
|
||||
// TODO: Here we should pass error info to a subscriber
|
||||
PhxReply::Error(info) => {
|
||||
tracing::debug!("Portal error: {info:?}");
|
||||
handler(Err(ErrorReply { error: info }), m.reference, m.topic).await
|
||||
handler(Err(ChannelError::ErrorReply(info)), m.reference, m.topic)
|
||||
.await
|
||||
}
|
||||
PhxReply::Ok(reply) => match reply {
|
||||
OkReply::NoMessage(Empty {}) => {
|
||||
@@ -241,6 +239,17 @@ where
|
||||
},
|
||||
ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"),
|
||||
},
|
||||
Payload::ControlMessage(ControlMessage::PhxClose(_)) => {
|
||||
return Err(Error::ClosedByPortal)
|
||||
}
|
||||
Payload::ControlMessage(ControlMessage::TokenExpired(_)) => {
|
||||
handler(
|
||||
Err(ChannelError::ErrorMsg(Error::TokenExpired)),
|
||||
m.reference,
|
||||
m.topic,
|
||||
)
|
||||
.await
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!(message = "Error deserializing message", message_string = m_str, error = ?e);
|
||||
@@ -248,6 +257,8 @@ where
|
||||
},
|
||||
_ => tracing::error!("Received message that is not text"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Obtains a new sender that can be used to send message with this [PhoenixChannel] to the portal.
|
||||
@@ -297,14 +308,12 @@ where
|
||||
|
||||
/// A result type that is used to communicate to the client/gateway
|
||||
/// control loop the message received.
|
||||
pub type MessageResult<M> = std::result::Result<M, ErrorReply>;
|
||||
pub type MessageResult<M> = std::result::Result<M, ChannelError>;
|
||||
|
||||
/// This struct holds info about an error reply which will be passed
|
||||
/// to connlib's control plane.
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
||||
pub struct ErrorReply {
|
||||
/// Information of the error
|
||||
pub error: ErrorInfo,
|
||||
#[derive(Debug)]
|
||||
pub enum ChannelError {
|
||||
ErrorReply(ErrorInfo),
|
||||
ErrorMsg(Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
||||
@@ -314,9 +323,17 @@ enum Payload<T, R> {
|
||||
// but that makes everything even more convoluted!
|
||||
// and we need to think how to make this whole mess less convoluted.
|
||||
Reply(ReplyMessage<R>),
|
||||
ControlMessage(ControlMessage),
|
||||
Message(T),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
||||
enum ControlMessage {
|
||||
PhxClose(Empty),
|
||||
TokenExpired(Empty),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
|
||||
pub struct PhoenixMessage<T, R> {
|
||||
// TODO: we should use a newtype pattern for topics
|
||||
@@ -526,8 +543,8 @@ impl PhoenixSender {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::control::{
|
||||
ErrorInfo, KnownError, Payload, PhxReply::Error, Reason, ReplyMessage::PhxReply,
|
||||
UnknownError,
|
||||
ControlMessage, Empty, ErrorInfo, KnownError, Payload, PhxReply::Error, Reason,
|
||||
ReplyMessage::PhxReply, UnknownError,
|
||||
};
|
||||
|
||||
#[test]
|
||||
@@ -552,6 +569,37 @@ mod tests {
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn phx_close() {
|
||||
let actual_reply = r#"
|
||||
{
|
||||
"event": "phx_close",
|
||||
"ref": null,
|
||||
"topic": "client",
|
||||
"payload": {}
|
||||
}
|
||||
"#;
|
||||
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<(), ()>::ControlMessage(ControlMessage::PhxClose(Empty {}));
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_expired() {
|
||||
let actual_reply = r#"
|
||||
{
|
||||
"event": "token_expired",
|
||||
"ref": null,
|
||||
"topic": "client",
|
||||
"payload": {}
|
||||
}
|
||||
"#;
|
||||
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
||||
let expected_reply =
|
||||
Payload::<(), ()>::ControlMessage(ControlMessage::TokenExpired(Empty {}));
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unexpected_error_reply() {
|
||||
let actual_reply = r#"
|
||||
|
||||
@@ -153,6 +153,8 @@ pub enum ConnlibError {
|
||||
TokenExpired,
|
||||
#[error("Too many concurrent gateway connection requests")]
|
||||
TooManyConnectionRequests,
|
||||
#[error("Channel connection closed by portal")]
|
||||
ClosedByPortal,
|
||||
}
|
||||
|
||||
impl ConnlibError {
|
||||
|
||||
Reference in New Issue
Block a user