refactor(portal): unify format of error payloads in websocket connection (#3697)

Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
Andrew Dryga
2024-02-28 17:06:52 -06:00
committed by GitHub
parent cea7784730
commit bfe1fb0ff4
10 changed files with 343 additions and 269 deletions

View File

@@ -39,7 +39,7 @@ defmodule API.Client.Channel do
Process.send_after(self(), :token_expired, expires_in)
{:ok, socket}
else
{:error, %{"reason" => "token_expired"}}
{:error, %{reason: :token_expired}}
end
end
@@ -107,7 +107,7 @@ defmodule API.Client.Channel do
OpenTelemetry.Tracer.set_current_span(socket.assigns.opentelemetry_span_ctx)
OpenTelemetry.Tracer.with_span "client.token_expired" do
push(socket, "disconnect", %{"reason" => "token_expired"})
push(socket, "disconnect", %{reason: :token_expired})
{:stop, {:shutdown, :token_expired}, socket}
end
end
@@ -115,7 +115,7 @@ defmodule API.Client.Channel do
# This message is sent using Clients.broadcast_to_client/1 eg. when the client is deleted
def handle_info("disconnect", socket) do
OpenTelemetry.Tracer.with_span "client.disconnect" do
push(socket, "disconnect", %{"reason" => "token_expired"})
push(socket, "disconnect", %{reason: :token_expired})
send(socket.transport_pid, %Phoenix.Socket.Broadcast{event: "disconnect"})
{:stop, :shutdown, socket}
end
@@ -285,7 +285,7 @@ defmodule API.Client.Channel do
OpenTelemetry.Tracer.with_span "client.create_log_sink" do
case Instrumentation.create_remote_log_sink(socket.assigns.client, actor_name, account_slug) do
{:ok, signed_url} -> {:reply, {:ok, signed_url}, socket}
{:error, :disabled} -> {:reply, {:error, :disabled}, socket}
{:error, :disabled} -> {:reply, {:error, %{reason: :disabled}}, socket}
end
end
end
@@ -343,11 +343,11 @@ defmodule API.Client.Channel do
else
{:ok, []} ->
OpenTelemetry.Tracer.set_status(:error, "offline")
{:reply, {:error, :offline}, socket}
{:reply, {:error, %{reason: :offline}}, socket}
{:error, :not_found} ->
OpenTelemetry.Tracer.set_status(:error, "not_found")
{:reply, {:error, :not_found}, socket}
{:reply, {:error, %{reason: :not_found}}, socket}
end
end
end
@@ -396,11 +396,11 @@ defmodule API.Client.Channel do
else
{:error, :not_found} ->
OpenTelemetry.Tracer.set_status(:error, "not_found")
{:reply, {:error, :not_found}, socket}
{:reply, {:error, %{reason: :not_found}}, socket}
false ->
OpenTelemetry.Tracer.set_status(:error, "offline")
{:reply, {:error, :offline}, socket}
{:reply, {:error, %{reason: :offline}}, socket}
end
end
end
@@ -452,11 +452,11 @@ defmodule API.Client.Channel do
else
{:error, :not_found} ->
OpenTelemetry.Tracer.set_status(:error, "not_found")
{:reply, {:error, :not_found}, socket}
{:reply, {:error, %{reason: :not_found}}, socket}
false ->
OpenTelemetry.Tracer.set_status(:error, "offline")
{:reply, {:error, :offline}, socket}
{:reply, {:error, %{reason: :offline}}, socket}
end
end
end

View File

@@ -35,14 +35,15 @@ defmodule API.Gateway.Channel do
:ok = Gateways.connect_gateway(socket.assigns.gateway)
config = Domain.Config.fetch_env!(:domain, Domain.Gateways)
ipv4_masquerade_enabled = Keyword.fetch!(config, :gateway_ipv4_masquerade)
ipv6_masquerade_enabled = Keyword.fetch!(config, :gateway_ipv6_masquerade)
ipv4_masquerade_enabled? = Keyword.fetch!(config, :gateway_ipv4_masquerade)
ipv6_masquerade_enabled? = Keyword.fetch!(config, :gateway_ipv6_masquerade)
push(socket, "init", %{
interface: Views.Interface.render(socket.assigns.gateway),
# TODO: move to settings
ipv4_masquerade_enabled: ipv4_masquerade_enabled,
ipv6_masquerade_enabled: ipv6_masquerade_enabled
config: %{
ipv4_masquerade_enabled: ipv4_masquerade_enabled?,
ipv6_masquerade_enabled: ipv6_masquerade_enabled?
}
})
{:noreply, socket}
@@ -259,8 +260,15 @@ defmodule API.Gateway.Channel do
},
socket
) do
{{channel_pid, socket_ref, resource_id, {opentelemetry_ctx, opentelemetry_span_ctx}}, refs} =
Map.pop(socket.assigns.refs, ref)
{
{
channel_pid,
socket_ref,
resource_id,
{opentelemetry_ctx, opentelemetry_span_ctx}
},
refs
} = Map.pop(socket.assigns.refs, ref)
OpenTelemetry.Ctx.attach(opentelemetry_ctx)
OpenTelemetry.Tracer.set_current_span(opentelemetry_span_ctx)

View File

@@ -131,7 +131,7 @@ defmodule API.Client.ChannelTest do
})
|> subscribe_and_join(API.Client.Channel, "client")
assert_push "disconnect", %{"reason" => "token_expired"}, 250
assert_push "disconnect", %{reason: :token_expired}, 250
assert_receive {:EXIT, _pid, {:shutdown, :token_expired}}
assert_receive {:socket_close, _pid, {:shutdown, :token_expired}}
end
@@ -207,7 +207,7 @@ defmodule API.Client.ChannelTest do
assert_push "init", %{}
Process.flag(:trap_exit, true)
Domain.Clients.broadcast_to_client(client, :token_expired)
assert_push "disconnect", %{"reason" => "token_expired"}, 250
assert_push "disconnect", %{reason: :token_expired}, 250
end
test "subscribes for resource events", %{
@@ -269,7 +269,7 @@ defmodule API.Client.ChannelTest do
channel_pid = socket.channel_pid
send(channel_pid, :token_expired)
assert_push "disconnect", %{"reason" => "token_expired"}
assert_push "disconnect", %{reason: :token_expired}
assert_receive {:EXIT, ^channel_pid, {:shutdown, :token_expired}}
end
@@ -470,7 +470,7 @@ defmodule API.Client.ChannelTest do
Domain.Config.put_env_override(Domain.Instrumentation, client_logs_enabled: false)
ref = push(socket, "create_log_sink", %{})
assert_reply ref, :error, :disabled
assert_reply ref, :error, %{reason: :disabled}
end
test "returns a signed URL which can be used to upload the logs", %{
@@ -509,7 +509,7 @@ defmodule API.Client.ChannelTest do
describe "handle_in/3 prepare_connection" do
test "returns error when resource is not found", %{socket: socket} do
ref = push(socket, "prepare_connection", %{"resource_id" => Ecto.UUID.generate()})
assert_reply ref, :error, :not_found
assert_reply ref, :error, %{reason: :not_found}
end
test "returns error when there are no online relays", %{
@@ -517,7 +517,7 @@ defmodule API.Client.ChannelTest do
socket: socket
} do
ref = push(socket, "prepare_connection", %{"resource_id" => resource.id})
assert_reply ref, :error, :offline
assert_reply ref, :error, %{reason: :offline}
end
test "returns error when all gateways are offline", %{
@@ -525,7 +525,7 @@ defmodule API.Client.ChannelTest do
socket: socket
} do
ref = push(socket, "prepare_connection", %{"resource_id" => resource.id})
assert_reply ref, :error, :offline
assert_reply ref, :error, %{reason: :offline}
end
test "returns error when client has no policy allowing access to resource", %{
@@ -542,7 +542,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "prepare_connection", attrs)
assert_reply ref, :error, :not_found
assert_reply ref, :error, %{reason: :not_found}
end
test "returns error when all gateways connected to the resource are offline", %{
@@ -554,7 +554,7 @@ defmodule API.Client.ChannelTest do
:ok = Domain.Gateways.connect_gateway(gateway)
ref = push(socket, "prepare_connection", %{"resource_id" => resource.id})
assert_reply ref, :error, :offline
assert_reply ref, :error, %{reason: :offline}
end
test "returns online gateway and relays connected to the resource", %{
@@ -836,7 +836,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "reuse_connection", attrs)
assert_reply ref, :error, :not_found
assert_reply ref, :error, %{reason: :not_found}
end
test "returns error when gateway is not found", %{dns_resource: resource, socket: socket} do
@@ -847,7 +847,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "reuse_connection", attrs)
assert_reply ref, :error, :not_found
assert_reply ref, :error, %{reason: :not_found}
end
test "returns error when gateway is not connected to resource", %{
@@ -865,7 +865,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "reuse_connection", attrs)
assert_reply ref, :error, :offline
assert_reply ref, :error, %{reason: :offline}
end
test "returns error when client has no policy allowing access to resource", %{
@@ -884,7 +884,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "reuse_connection", attrs)
assert_reply ref, :error, :not_found
assert_reply ref, :error, %{reason: :not_found}
end
test "returns error when gateway is offline", %{
@@ -899,7 +899,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "reuse_connection", attrs)
assert_reply ref, :error, :offline
assert_reply ref, :error, %{reason: :offline}
end
test "broadcasts allow_access to the gateways and then returns connect message", %{
@@ -995,7 +995,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "request_connection", attrs)
assert_reply ref, :error, :not_found
assert_reply ref, :error, %{reason: :not_found}
end
test "returns error when gateway is not found", %{dns_resource: resource, socket: socket} do
@@ -1007,7 +1007,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "request_connection", attrs)
assert_reply ref, :error, :not_found
assert_reply ref, :error, %{reason: :not_found}
end
test "returns error when gateway is not connected to resource", %{
@@ -1026,7 +1026,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "request_connection", attrs)
assert_reply ref, :error, :offline
assert_reply ref, :error, %{reason: :offline}
end
test "returns error when client has no policy allowing access to resource", %{
@@ -1046,7 +1046,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "request_connection", attrs)
assert_reply ref, :error, :not_found
assert_reply ref, :error, %{reason: :not_found}
end
test "returns error when gateway is offline", %{
@@ -1062,7 +1062,7 @@ defmodule API.Client.ChannelTest do
}
ref = push(socket, "request_connection", attrs)
assert_reply ref, :error, :offline
assert_reply ref, :error, %{reason: :offline}
end
test "broadcasts request_connection to the gateways and then returns connect message", %{

View File

@@ -57,8 +57,10 @@ defmodule API.Gateway.ChannelTest do
} do
assert_push "init", %{
interface: interface,
ipv4_masquerade_enabled: true,
ipv6_masquerade_enabled: true
config: %{
ipv4_masquerade_enabled: true,
ipv6_masquerade_enabled: true
}
}
assert interface == %{

View File

@@ -1,8 +1,6 @@
use async_compression::tokio::bufread::GzipEncoder;
use bimap::BiMap;
use connlib_shared::control::ChannelError;
use connlib_shared::control::KnownError;
use connlib_shared::control::Reason;
use connlib_shared::control::{ChannelError, ErrorReply};
use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer};
use connlib_shared::IpProvider;
use firezone_tunnel::ClientTunnel;
@@ -17,7 +15,7 @@ use crate::messages::{
GatewayIceCandidates, InitClient, Messages,
};
use connlib_shared::{
control::{ErrorInfo, PhoenixSenderWithTopic, Reference},
control::{PhoenixSenderWithTopic, Reference},
messages::{GatewayId, ResourceDescription, ResourceId},
Callbacks,
Error::{self},
@@ -334,7 +332,7 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
topic: String,
) -> Result<()> {
match (reply_error, reference) {
(ChannelError::ErrorReply(ErrorInfo::Offline), Some(reference)) => {
(ChannelError::ErrorReply(ErrorReply::Offline), Some(reference)) => {
let Ok(request_id) = reference.parse::<usize>() else {
return Ok(());
};
@@ -347,22 +345,12 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
self.tunnel.cleanup_connection(resource);
}
(
ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known(
KnownError::UnmatchedTopic,
))),
_,
) => {
(ChannelError::ErrorReply(ErrorReply::UnmatchedTopic), _) => {
if let Err(e) = self.phoenix_channel.get_sender().join_topic(topic).await {
tracing::debug!(err = ?e, "couldn't join topic: {e:#?}");
}
}
(
ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known(
KnownError::TokenExpired,
))),
_,
)
(ChannelError::ErrorReply(ErrorReply::TokenExpired), _)
| (ChannelError::ErrorMsg(Error::ClosedByPortal), _) => {
return Err(Error::ClosedByPortal);
}

View File

@@ -161,7 +161,6 @@ mod test {
};
use chrono::NaiveDateTime;
use connlib_shared::control::ErrorInfo;
use crate::messages::{ConnectionDetails, EgressMessages, ReplyMessages};
@@ -403,13 +402,13 @@ mod test {
#[test]
fn create_log_sink_error_response() {
let json = r#"{"event":"phx_reply","ref":"unique_log_sink_ref","topic":"client","payload":{"status":"error","response":"disabled"}}"#;
let json = r#"{"event":"phx_reply","ref":"unique_log_sink_ref","topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#;
let actual =
serde_json::from_str::<PhoenixMessage<EgressMessages, ReplyMessages>>(json).unwrap();
let expected = PhoenixMessage::new_err_reply(
"client",
ErrorInfo::Disabled,
connlib_shared::control::ErrorReply::Disabled,
"unique_log_sink_ref".to_owned(),
);

View File

@@ -217,28 +217,24 @@ where
handler(Ok(payload.into()), m.reference, m.topic).await
}
Payload::Reply(status) => match status {
ReplyMessage::PhxReply(phx_reply) => match phx_reply {
// TODO: Here we should pass error info to a subscriber
PhxReply::Error(info) => {
tracing::debug!("Portal error: {info:?}");
handler(Err(ChannelError::ErrorReply(info)), m.reference, m.topic)
.await
// TODO: Here we should pass error info to a subscriber
Reply::Error { reason } => {
tracing::debug!("Portal error: {reason:?}");
handler(Err(ChannelError::ErrorReply(reason)), m.reference, m.topic)
.await
}
Reply::Ok(reply) => match reply {
OkReply::NoMessage(Empty {}) => {
tracing::trace!(target: "phoenix_status", "Phoenix status message")
}
OkReply::Message(payload) => {
handler(Ok(payload.into()), m.reference, m.topic).await
}
PhxReply::Ok(reply) => match reply {
OkReply::NoMessage(Empty {}) => {
tracing::trace!(target: "phoenix_status", "Phoenix status message")
}
OkReply::Message(payload) => {
handler(Ok(payload.into()), m.reference, m.topic).await
}
},
},
ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"),
},
Payload::ControlMessage(ControlMessage::PhxClose(_)) => {
return Err(Error::ClosedByPortal)
}
Payload::ControlMessage(ControlMessage::Disconnect { reason: _reason }) => {
Payload::Error(_) => {}
Payload::Close(Empty {}) => return Err(Error::ClosedByPortal),
Payload::Disconnect { reason: _reason } => {
// TODO: pass the _reason up to the client so it can print a pertinent user message
handler(
Err(ChannelError::ErrorMsg(Error::ClosedByPortal)),
@@ -309,28 +305,10 @@ pub type MessageResult<M> = std::result::Result<M, ChannelError>;
#[derive(Debug)]
pub enum ChannelError {
ErrorReply(ErrorInfo),
ErrorReply(ErrorReply),
ErrorMsg(Error),
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(untagged)]
enum Payload<T, R> {
// We might want other type for the reply message
// but that makes everything even more convoluted!
// and we need to think how to make this whole mess less convoluted.
Reply(ReplyMessage<R>),
ControlMessage(ControlMessage),
Message(T),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum ControlMessage {
PhxClose(Empty),
Disconnect { reason: String },
}
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
pub struct PhoenixMessage<T, R> {
// TODO: we should use a newtype pattern for topics
@@ -341,6 +319,54 @@ pub struct PhoenixMessage<T, R> {
reference: Option<String>,
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(tag = "event", content = "payload")]
enum Payload<T, R> {
#[serde(rename = "phx_reply")]
Reply(Reply<R>),
#[serde(rename = "phx_error")]
Error(Empty),
#[serde(rename = "phx_close")]
Close(Empty),
#[serde(rename = "disconnect")]
Disconnect { reason: String },
#[serde(untagged)]
Message(T),
}
// Awful hack to get serde_json to generate an empty "{}" instead of using "null"
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
struct Empty {}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
enum Reply<T> {
Ok(OkReply<T>),
Error { reason: ErrorReply },
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
enum OkReply<T> {
Message(T),
NoMessage(Empty),
}
/// This represents the info we have about the error
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ErrorReply {
#[serde(rename = "unmatched topic")]
UnmatchedTopic,
TokenExpired,
NotFound,
Offline,
Disabled,
#[serde(other)]
Other,
}
impl<T, R> PhoenixMessage<T, R> {
pub fn new(topic: impl Into<String>, payload: T, reference: Option<String>) -> Self {
Self {
@@ -358,32 +384,25 @@ impl<T, R> PhoenixMessage<T, R> {
Self {
topic: topic.into(),
// There has to be a better way :\
payload: Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::Message(
payload,
)))),
payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))),
reference: reference.into(),
}
}
pub fn new_err_reply(
topic: impl Into<String>,
error: ErrorInfo,
reason: ErrorReply,
reference: impl Into<Option<String>>,
) -> Self {
Self {
topic: topic.into(),
// There has to be a better way :\
payload: Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(error))),
payload: Payload::Reply(Reply::Error { reason }),
reference: reference.into(),
}
}
}
// Awful hack to get serde_json to generate an empty "{}" instead of using "null"
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
struct Empty {}
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum EgressControlMessage {
@@ -391,54 +410,6 @@ enum EgressControlMessage {
Heartbeat(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum ReplyMessage<T> {
PhxReply(PhxReply<T>),
PhxError(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
enum OkReply<T> {
Message(T),
NoMessage(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct UnknownError(pub String);
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub enum KnownError {
#[serde(rename = "unmatched topic")]
UnmatchedTopic,
#[serde(rename = "token_expired")]
TokenExpired,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum Reason {
Known(KnownError),
Unknown(UnknownError),
}
/// This represents the info we have about the error
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ErrorInfo {
Reason(Reason),
Offline,
Disabled,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
enum PhxReply<T> {
Ok(OkReply<T>),
Error(ErrorInfo),
}
/// You can use this sender to send messages through a `PhoenixChannel`.
///
/// Messages won't be sent unless [PhoenixChannel::start] is running, internally
@@ -539,30 +510,27 @@ impl PhoenixSender {
#[cfg(test)]
mod tests {
use crate::control::{
ControlMessage, Empty, ErrorInfo, KnownError, Payload, PhxReply::Error, Reason,
ReplyMessage::PhxReply, UnknownError,
};
use super::*;
#[test]
fn unmatched_topic_reply() {
let actual_reply = r#"
{
"event":"phx_reply",
"ref":"12",
"topic":"client",
"event": "phx_reply",
"ref": "12",
"topic": "client",
"payload":{
"status":"error",
"status": "error",
"response":{
"reason":"unmatched topic"
"reason": "unmatched topic"
}
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(PhxReply(Error(ErrorInfo::Reason(
Reason::Known(KnownError::UnmatchedTopic),
))));
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
reason: ErrorReply::UnmatchedTopic,
});
assert_eq!(actual_reply, expected_reply);
}
@@ -577,7 +545,7 @@ mod tests {
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::ControlMessage(ControlMessage::PhxClose(Empty {}));
let expected_reply = Payload::<(), ()>::Close(Empty {});
assert_eq!(actual_reply, expected_reply);
}
@@ -592,8 +560,30 @@ mod tests {
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::ControlMessage(ControlMessage::Disconnect {
let expected_reply = Payload::<(), ()>::Disconnect {
reason: "token_expired".to_string(),
};
assert_eq!(actual_reply, expected_reply);
}
#[test]
fn not_found() {
let actual_reply = r#"
{
"event": "phx_reply",
"ref": null,
"topic": "client",
"payload": {
"status": "error",
"response": {
"reason": "not_found"
}
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
reason: ErrorReply::NotFound,
});
assert_eq!(actual_reply, expected_reply);
}
@@ -602,21 +592,21 @@ mod tests {
fn unexpected_error_reply() {
let actual_reply = r#"
{
"event":"phx_reply",
"ref":"12",
"topic":"client",
"payload":{
"status":"error",
"response":{
"reason":"bad reply"
"event": "phx_reply",
"ref": "12",
"topic": "client",
"payload": {
"status": "error",
"response": {
"reason": "bad reply"
}
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(PhxReply(Error(ErrorInfo::Reason(
Reason::Unknown(UnknownError("bad reply".to_string())),
))));
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
reason: ErrorReply::Other,
});
assert_eq!(actual_reply, expected_reply);
}
}

View File

@@ -13,8 +13,7 @@ use std::net::IpAddr;
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
pub struct InitGateway {
pub interface: Interface,
pub ipv4_masquerade_enabled: bool,
pub ipv6_masquerade_enabled: bool,
pub config: Config,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
@@ -22,6 +21,12 @@ pub struct Actor {
pub id: ActorId,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Config {
pub ipv4_masquerade_enabled: bool,
pub ipv6_masquerade_enabled: bool,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Client {
pub id: ClientId,
@@ -141,11 +146,10 @@ pub struct ConnectionReady {
#[cfg(test)]
mod test {
use connlib_shared::{control::PhoenixMessage, messages::Interface};
use super::*;
use connlib_shared::control::PhoenixMessage;
use phoenix_channel::InitMessage;
use super::{IngressMessages, InitGateway};
#[test]
fn request_connection_message() {
let message = r#"{
@@ -206,11 +210,13 @@ mod test {
ipv6: "fd00:2021:1111::2c:f6ab".parse().unwrap(),
upstream_dns: vec![],
},
ipv4_masquerade_enabled: true,
ipv6_masquerade_enabled: true,
config: Config {
ipv4_masquerade_enabled: true,
ipv6_masquerade_enabled: true,
},
});
let message = r#"{"event":"init","ref":null,"topic":"gateway","payload":{"interface":{"ipv6":"fd00:2021:1111::2c:f6ab","ipv4":"100.115.164.78"},"ipv4_masquerade_enabled":true,"ipv6_masquerade_enabled":true}}"#;
let message = r#"{"event":"init","ref":null,"topic":"gateway","payload":{"interface":{"ipv6":"fd00:2021:1111::2c:f6ab","ipv4":"100.115.164.78"},"config":{"ipv4_masquerade_enabled":true,"ipv6_masquerade_enabled":true}}}"#;
let ingress_message = serde_json::from_str::<InitMessage<InitGateway>>(message).unwrap();
assert_eq!(m, ingress_message);
}

View File

@@ -344,9 +344,7 @@ where
}))
}
},
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(
ErrorInfo::Reason(reason),
))) => {
Payload::Reply(Reply::Error { reason }) => {
return Poll::Ready(Ok(Event::ErrorResponse {
topic: message.topic,
req_id: OutboundRequestId(
@@ -355,9 +353,7 @@ where
reason,
}));
}
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::Message(
reply,
)))) => {
Payload::Reply(Reply::Ok(OkReply::Message(reply))) => {
let req_id =
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
@@ -376,18 +372,7 @@ where
res: reply,
}));
}
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(
ErrorInfo::Offline,
))) => {
tracing::warn!(
"Received offline error for request {:?}",
message.reference
);
continue;
}
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(
OkReply::NoMessage(Empty {}),
))) => {
Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))) => {
let id =
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
@@ -402,20 +387,20 @@ where
continue;
}
Payload::Reply(ReplyMessage::PhxError(Empty {})) => {
Payload::Error(Empty {}) => {
return Poll::Ready(Ok(Event::ErrorResponse {
topic: message.topic,
req_id: OutboundRequestId(
message.reference.ok_or(Error::MissingReplyId)?,
),
reason: "unknown error (bad event?)".to_owned(),
reason: ErrorReply::Other,
}))
}
Payload::ControlMessage(ControlMessage::PhxClose(_)) => {
Payload::Close(Empty {}) => {
self.reconnect_on_transient_error(Error::CloseMessage);
continue;
}
Payload::ControlMessage(ControlMessage::Disconnect { reason }) => {
Payload::Disconnect { reason } => {
return Poll::Ready(Ok(Event::Disconnect(reason)));
}
}
@@ -523,7 +508,7 @@ pub enum Event<TInboundMsg, TOutboundRes> {
ErrorResponse {
topic: String,
req_id: OutboundRequestId,
reason: String,
reason: ErrorReply,
},
/// The server sent us a message, most likely this is a broadcast to all connected clients.
InboundMessage {
@@ -538,26 +523,9 @@ pub enum Event<TInboundMsg, TOutboundRes> {
Disconnect(String),
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(untagged)]
enum Payload<T, R> {
// We might want other type for the reply message
// but that makes everything even more convoluted!
// and we need to think how to make this whole mess less convoluted.
Reply(ReplyMessage<R>),
ControlMessage(ControlMessage),
Message(T),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum ControlMessage {
PhxClose(Empty),
Disconnect { reason: String },
}
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
pub struct PhoenixMessage<T, R> {
// TODO: we should use a newtype pattern for topics
topic: String,
#[serde(flatten)]
payload: Payload<T, R>,
@@ -565,6 +533,54 @@ pub struct PhoenixMessage<T, R> {
reference: Option<u64>,
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(tag = "event", content = "payload")]
enum Payload<T, R> {
#[serde(rename = "phx_reply")]
Reply(Reply<R>),
#[serde(rename = "phx_error")]
Error(Empty),
#[serde(rename = "phx_close")]
Close(Empty),
#[serde(rename = "disconnect")]
Disconnect { reason: String },
#[serde(untagged)]
Message(T),
}
// Awful hack to get serde_json to generate an empty "{}" instead of using "null"
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
struct Empty {}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
enum Reply<T> {
Ok(OkReply<T>),
Error { reason: ErrorReply },
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
enum OkReply<T> {
Message(T),
NoMessage(Empty),
}
/// This represents the info we have about the error
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ErrorReply {
#[serde(rename = "unmatched topic")]
UnmatchedTopic,
TokenExpired,
NotFound,
Offline,
Disabled,
#[serde(other)]
Other,
}
impl<T, R> PhoenixMessage<T, R> {
pub fn new(topic: impl Into<String>, payload: T, reference: u64) -> Self {
Self {
@@ -609,11 +625,6 @@ fn make_request(secret_url: Secret<SecureUrl>, user_agent: String) -> Result<Req
Ok(req)
}
// Awful hack to get serde_json to generate an empty "{}" instead of using "null"
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
struct Empty {}
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum EgressControlMessage<T> {
@@ -621,34 +632,6 @@ enum EgressControlMessage<T> {
Heartbeat(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum ReplyMessage<T> {
PhxReply(PhxReply<T>),
PhxError(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
enum OkReply<T> {
Message(T),
NoMessage(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum ErrorInfo {
Reason(String),
Offline,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
enum PhxReply<T> {
Ok(OkReply<T>),
Error(ErrorInfo),
}
#[cfg(test)]
mod tests {
use super::*;
@@ -662,14 +645,14 @@ mod tests {
#[test]
fn can_deserialize_inbound_message() {
let msg = r#"{
"topic": "room:lobby",
"ref": null,
"payload": {
"hello": "world"
},
"join_ref": null,
"event": "shout"
}"#;
"topic": "room:lobby",
"ref": null,
"payload": {
"hello": "world"
},
"join_ref": null,
"event": "shout"
}"#;
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
@@ -698,4 +681,102 @@ mod tests {
Payload::Message(InitMessage::Init(EmptyInit {}))
);
}
#[test]
fn unmatched_topic_reply() {
let actual_reply = r#"
{
"event": "phx_reply",
"ref": "12",
"topic": "client",
"payload":{
"status": "error",
"response":{
"reason": "unmatched topic"
}
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
reason: ErrorReply::UnmatchedTopic,
});
assert_eq!(actual_reply, expected_reply);
}
#[test]
fn phx_close() {
let actual_reply = r#"
{
"event": "phx_close",
"ref": null,
"topic": "client",
"payload": {}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Close(Empty {});
assert_eq!(actual_reply, expected_reply);
}
#[test]
fn token_expired() {
let actual_reply = r#"
{
"event": "disconnect",
"ref": null,
"topic": "client",
"payload": { "reason": "token_expired" }
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Disconnect {
reason: "token_expired".to_string(),
};
assert_eq!(actual_reply, expected_reply);
}
#[test]
fn not_found() {
let actual_reply = r#"
{
"event": "phx_reply",
"ref": null,
"topic": "client",
"payload": {
"status": "error",
"response": {
"reason": "not_found"
}
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
reason: ErrorReply::NotFound,
});
assert_eq!(actual_reply, expected_reply);
}
#[test]
fn unexpected_error_reply() {
let actual_reply = r#"
{
"event": "phx_reply",
"ref": "12",
"topic": "client",
"payload": {
"status": "error",
"response": {
"reason": "bad reply"
}
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
reason: ErrorReply::Other,
});
assert_eq!(actual_reply, expected_reply);
}
}

View File

@@ -516,7 +516,7 @@ where
req_id,
reason,
}))) => {
tracing::warn!(target: "relay", "Request with ID {req_id} on topic {topic} failed: {reason}");
tracing::warn!(target: "relay", "Request with ID {req_id} on topic {topic} failed: {reason:?}");
continue;
}
Some(Poll::Ready(Ok(Event::HeartbeatSent))) => {