diff --git a/elixir/apps/api/lib/api/client/channel.ex b/elixir/apps/api/lib/api/client/channel.ex index 7708bc756..77505ab0f 100644 --- a/elixir/apps/api/lib/api/client/channel.ex +++ b/elixir/apps/api/lib/api/client/channel.ex @@ -39,7 +39,7 @@ defmodule API.Client.Channel do Process.send_after(self(), :token_expired, expires_in) {:ok, socket} else - {:error, %{"reason" => "token_expired"}} + {:error, %{reason: :token_expired}} end end @@ -107,7 +107,7 @@ defmodule API.Client.Channel do OpenTelemetry.Tracer.set_current_span(socket.assigns.opentelemetry_span_ctx) OpenTelemetry.Tracer.with_span "client.token_expired" do - push(socket, "disconnect", %{"reason" => "token_expired"}) + push(socket, "disconnect", %{reason: :token_expired}) {:stop, {:shutdown, :token_expired}, socket} end end @@ -115,7 +115,7 @@ defmodule API.Client.Channel do # This message is sent using Clients.broadcast_to_client/1 eg. when the client is deleted def handle_info("disconnect", socket) do OpenTelemetry.Tracer.with_span "client.disconnect" do - push(socket, "disconnect", %{"reason" => "token_expired"}) + push(socket, "disconnect", %{reason: :token_expired}) send(socket.transport_pid, %Phoenix.Socket.Broadcast{event: "disconnect"}) {:stop, :shutdown, socket} end @@ -285,7 +285,7 @@ defmodule API.Client.Channel do OpenTelemetry.Tracer.with_span "client.create_log_sink" do case Instrumentation.create_remote_log_sink(socket.assigns.client, actor_name, account_slug) do {:ok, signed_url} -> {:reply, {:ok, signed_url}, socket} - {:error, :disabled} -> {:reply, {:error, :disabled}, socket} + {:error, :disabled} -> {:reply, {:error, %{reason: :disabled}}, socket} end end end @@ -343,11 +343,11 @@ defmodule API.Client.Channel do else {:ok, []} -> OpenTelemetry.Tracer.set_status(:error, "offline") - {:reply, {:error, :offline}, socket} + {:reply, {:error, %{reason: :offline}}, socket} {:error, :not_found} -> OpenTelemetry.Tracer.set_status(:error, "not_found") - {:reply, {:error, :not_found}, socket} + {:reply, {:error, %{reason: :not_found}}, socket} end end end @@ -396,11 +396,11 @@ defmodule API.Client.Channel do else {:error, :not_found} -> OpenTelemetry.Tracer.set_status(:error, "not_found") - {:reply, {:error, :not_found}, socket} + {:reply, {:error, %{reason: :not_found}}, socket} false -> OpenTelemetry.Tracer.set_status(:error, "offline") - {:reply, {:error, :offline}, socket} + {:reply, {:error, %{reason: :offline}}, socket} end end end @@ -452,11 +452,11 @@ defmodule API.Client.Channel do else {:error, :not_found} -> OpenTelemetry.Tracer.set_status(:error, "not_found") - {:reply, {:error, :not_found}, socket} + {:reply, {:error, %{reason: :not_found}}, socket} false -> OpenTelemetry.Tracer.set_status(:error, "offline") - {:reply, {:error, :offline}, socket} + {:reply, {:error, %{reason: :offline}}, socket} end end end diff --git a/elixir/apps/api/lib/api/gateway/channel.ex b/elixir/apps/api/lib/api/gateway/channel.ex index c34a45766..8d501df72 100644 --- a/elixir/apps/api/lib/api/gateway/channel.ex +++ b/elixir/apps/api/lib/api/gateway/channel.ex @@ -35,14 +35,15 @@ defmodule API.Gateway.Channel do :ok = Gateways.connect_gateway(socket.assigns.gateway) config = Domain.Config.fetch_env!(:domain, Domain.Gateways) - ipv4_masquerade_enabled = Keyword.fetch!(config, :gateway_ipv4_masquerade) - ipv6_masquerade_enabled = Keyword.fetch!(config, :gateway_ipv6_masquerade) + ipv4_masquerade_enabled? = Keyword.fetch!(config, :gateway_ipv4_masquerade) + ipv6_masquerade_enabled? = Keyword.fetch!(config, :gateway_ipv6_masquerade) push(socket, "init", %{ interface: Views.Interface.render(socket.assigns.gateway), - # TODO: move to settings - ipv4_masquerade_enabled: ipv4_masquerade_enabled, - ipv6_masquerade_enabled: ipv6_masquerade_enabled + config: %{ + ipv4_masquerade_enabled: ipv4_masquerade_enabled?, + ipv6_masquerade_enabled: ipv6_masquerade_enabled? + } }) {:noreply, socket} @@ -259,8 +260,15 @@ defmodule API.Gateway.Channel do }, socket ) do - {{channel_pid, socket_ref, resource_id, {opentelemetry_ctx, opentelemetry_span_ctx}}, refs} = - Map.pop(socket.assigns.refs, ref) + { + { + channel_pid, + socket_ref, + resource_id, + {opentelemetry_ctx, opentelemetry_span_ctx} + }, + refs + } = Map.pop(socket.assigns.refs, ref) OpenTelemetry.Ctx.attach(opentelemetry_ctx) OpenTelemetry.Tracer.set_current_span(opentelemetry_span_ctx) diff --git a/elixir/apps/api/test/api/client/channel_test.exs b/elixir/apps/api/test/api/client/channel_test.exs index 5c8917a45..cb765048f 100644 --- a/elixir/apps/api/test/api/client/channel_test.exs +++ b/elixir/apps/api/test/api/client/channel_test.exs @@ -131,7 +131,7 @@ defmodule API.Client.ChannelTest do }) |> subscribe_and_join(API.Client.Channel, "client") - assert_push "disconnect", %{"reason" => "token_expired"}, 250 + assert_push "disconnect", %{reason: :token_expired}, 250 assert_receive {:EXIT, _pid, {:shutdown, :token_expired}} assert_receive {:socket_close, _pid, {:shutdown, :token_expired}} end @@ -207,7 +207,7 @@ defmodule API.Client.ChannelTest do assert_push "init", %{} Process.flag(:trap_exit, true) Domain.Clients.broadcast_to_client(client, :token_expired) - assert_push "disconnect", %{"reason" => "token_expired"}, 250 + assert_push "disconnect", %{reason: :token_expired}, 250 end test "subscribes for resource events", %{ @@ -269,7 +269,7 @@ defmodule API.Client.ChannelTest do channel_pid = socket.channel_pid send(channel_pid, :token_expired) - assert_push "disconnect", %{"reason" => "token_expired"} + assert_push "disconnect", %{reason: :token_expired} assert_receive {:EXIT, ^channel_pid, {:shutdown, :token_expired}} end @@ -470,7 +470,7 @@ defmodule API.Client.ChannelTest do Domain.Config.put_env_override(Domain.Instrumentation, client_logs_enabled: false) ref = push(socket, "create_log_sink", %{}) - assert_reply ref, :error, :disabled + assert_reply ref, :error, %{reason: :disabled} end test "returns a signed URL which can be used to upload the logs", %{ @@ -509,7 +509,7 @@ defmodule API.Client.ChannelTest do describe "handle_in/3 prepare_connection" do test "returns error when resource is not found", %{socket: socket} do ref = push(socket, "prepare_connection", %{"resource_id" => Ecto.UUID.generate()}) - assert_reply ref, :error, :not_found + assert_reply ref, :error, %{reason: :not_found} end test "returns error when there are no online relays", %{ @@ -517,7 +517,7 @@ defmodule API.Client.ChannelTest do socket: socket } do ref = push(socket, "prepare_connection", %{"resource_id" => resource.id}) - assert_reply ref, :error, :offline + assert_reply ref, :error, %{reason: :offline} end test "returns error when all gateways are offline", %{ @@ -525,7 +525,7 @@ defmodule API.Client.ChannelTest do socket: socket } do ref = push(socket, "prepare_connection", %{"resource_id" => resource.id}) - assert_reply ref, :error, :offline + assert_reply ref, :error, %{reason: :offline} end test "returns error when client has no policy allowing access to resource", %{ @@ -542,7 +542,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "prepare_connection", attrs) - assert_reply ref, :error, :not_found + assert_reply ref, :error, %{reason: :not_found} end test "returns error when all gateways connected to the resource are offline", %{ @@ -554,7 +554,7 @@ defmodule API.Client.ChannelTest do :ok = Domain.Gateways.connect_gateway(gateway) ref = push(socket, "prepare_connection", %{"resource_id" => resource.id}) - assert_reply ref, :error, :offline + assert_reply ref, :error, %{reason: :offline} end test "returns online gateway and relays connected to the resource", %{ @@ -836,7 +836,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "reuse_connection", attrs) - assert_reply ref, :error, :not_found + assert_reply ref, :error, %{reason: :not_found} end test "returns error when gateway is not found", %{dns_resource: resource, socket: socket} do @@ -847,7 +847,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "reuse_connection", attrs) - assert_reply ref, :error, :not_found + assert_reply ref, :error, %{reason: :not_found} end test "returns error when gateway is not connected to resource", %{ @@ -865,7 +865,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "reuse_connection", attrs) - assert_reply ref, :error, :offline + assert_reply ref, :error, %{reason: :offline} end test "returns error when client has no policy allowing access to resource", %{ @@ -884,7 +884,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "reuse_connection", attrs) - assert_reply ref, :error, :not_found + assert_reply ref, :error, %{reason: :not_found} end test "returns error when gateway is offline", %{ @@ -899,7 +899,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "reuse_connection", attrs) - assert_reply ref, :error, :offline + assert_reply ref, :error, %{reason: :offline} end test "broadcasts allow_access to the gateways and then returns connect message", %{ @@ -995,7 +995,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "request_connection", attrs) - assert_reply ref, :error, :not_found + assert_reply ref, :error, %{reason: :not_found} end test "returns error when gateway is not found", %{dns_resource: resource, socket: socket} do @@ -1007,7 +1007,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "request_connection", attrs) - assert_reply ref, :error, :not_found + assert_reply ref, :error, %{reason: :not_found} end test "returns error when gateway is not connected to resource", %{ @@ -1026,7 +1026,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "request_connection", attrs) - assert_reply ref, :error, :offline + assert_reply ref, :error, %{reason: :offline} end test "returns error when client has no policy allowing access to resource", %{ @@ -1046,7 +1046,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "request_connection", attrs) - assert_reply ref, :error, :not_found + assert_reply ref, :error, %{reason: :not_found} end test "returns error when gateway is offline", %{ @@ -1062,7 +1062,7 @@ defmodule API.Client.ChannelTest do } ref = push(socket, "request_connection", attrs) - assert_reply ref, :error, :offline + assert_reply ref, :error, %{reason: :offline} end test "broadcasts request_connection to the gateways and then returns connect message", %{ diff --git a/elixir/apps/api/test/api/gateway/channel_test.exs b/elixir/apps/api/test/api/gateway/channel_test.exs index 93310746d..ea78000ed 100644 --- a/elixir/apps/api/test/api/gateway/channel_test.exs +++ b/elixir/apps/api/test/api/gateway/channel_test.exs @@ -57,8 +57,10 @@ defmodule API.Gateway.ChannelTest do } do assert_push "init", %{ interface: interface, - ipv4_masquerade_enabled: true, - ipv6_masquerade_enabled: true + config: %{ + ipv4_masquerade_enabled: true, + ipv6_masquerade_enabled: true + } } assert interface == %{ diff --git a/rust/connlib/clients/shared/src/control.rs b/rust/connlib/clients/shared/src/control.rs index e36d6bdab..83a9a70dd 100644 --- a/rust/connlib/clients/shared/src/control.rs +++ b/rust/connlib/clients/shared/src/control.rs @@ -1,8 +1,6 @@ use async_compression::tokio::bufread::GzipEncoder; use bimap::BiMap; -use connlib_shared::control::ChannelError; -use connlib_shared::control::KnownError; -use connlib_shared::control::Reason; +use connlib_shared::control::{ChannelError, ErrorReply}; use connlib_shared::messages::{DnsServer, GatewayResponse, IpDnsServer}; use connlib_shared::IpProvider; use firezone_tunnel::ClientTunnel; @@ -17,7 +15,7 @@ use crate::messages::{ GatewayIceCandidates, InitClient, Messages, }; use connlib_shared::{ - control::{ErrorInfo, PhoenixSenderWithTopic, Reference}, + control::{PhoenixSenderWithTopic, Reference}, messages::{GatewayId, ResourceDescription, ResourceId}, Callbacks, Error::{self}, @@ -334,7 +332,7 @@ impl ControlPlane { topic: String, ) -> Result<()> { match (reply_error, reference) { - (ChannelError::ErrorReply(ErrorInfo::Offline), Some(reference)) => { + (ChannelError::ErrorReply(ErrorReply::Offline), Some(reference)) => { let Ok(request_id) = reference.parse::() else { return Ok(()); }; @@ -347,22 +345,12 @@ impl ControlPlane { self.tunnel.cleanup_connection(resource); } - ( - ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known( - KnownError::UnmatchedTopic, - ))), - _, - ) => { + (ChannelError::ErrorReply(ErrorReply::UnmatchedTopic), _) => { if let Err(e) = self.phoenix_channel.get_sender().join_topic(topic).await { tracing::debug!(err = ?e, "couldn't join topic: {e:#?}"); } } - ( - ChannelError::ErrorReply(ErrorInfo::Reason(Reason::Known( - KnownError::TokenExpired, - ))), - _, - ) + (ChannelError::ErrorReply(ErrorReply::TokenExpired), _) | (ChannelError::ErrorMsg(Error::ClosedByPortal), _) => { return Err(Error::ClosedByPortal); } diff --git a/rust/connlib/clients/shared/src/messages.rs b/rust/connlib/clients/shared/src/messages.rs index af08b7aa8..b179961db 100644 --- a/rust/connlib/clients/shared/src/messages.rs +++ b/rust/connlib/clients/shared/src/messages.rs @@ -161,7 +161,6 @@ mod test { }; use chrono::NaiveDateTime; - use connlib_shared::control::ErrorInfo; use crate::messages::{ConnectionDetails, EgressMessages, ReplyMessages}; @@ -403,13 +402,13 @@ mod test { #[test] fn create_log_sink_error_response() { - let json = r#"{"event":"phx_reply","ref":"unique_log_sink_ref","topic":"client","payload":{"status":"error","response":"disabled"}}"#; + let json = r#"{"event":"phx_reply","ref":"unique_log_sink_ref","topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#; let actual = serde_json::from_str::>(json).unwrap(); let expected = PhoenixMessage::new_err_reply( "client", - ErrorInfo::Disabled, + connlib_shared::control::ErrorReply::Disabled, "unique_log_sink_ref".to_owned(), ); diff --git a/rust/connlib/shared/src/control.rs b/rust/connlib/shared/src/control.rs index c18378d1c..2789b095c 100644 --- a/rust/connlib/shared/src/control.rs +++ b/rust/connlib/shared/src/control.rs @@ -217,28 +217,24 @@ where handler(Ok(payload.into()), m.reference, m.topic).await } Payload::Reply(status) => match status { - ReplyMessage::PhxReply(phx_reply) => match phx_reply { - // TODO: Here we should pass error info to a subscriber - PhxReply::Error(info) => { - tracing::debug!("Portal error: {info:?}"); - handler(Err(ChannelError::ErrorReply(info)), m.reference, m.topic) - .await + // TODO: Here we should pass error info to a subscriber + Reply::Error { reason } => { + tracing::debug!("Portal error: {reason:?}"); + handler(Err(ChannelError::ErrorReply(reason)), m.reference, m.topic) + .await + } + Reply::Ok(reply) => match reply { + OkReply::NoMessage(Empty {}) => { + tracing::trace!(target: "phoenix_status", "Phoenix status message") + } + OkReply::Message(payload) => { + handler(Ok(payload.into()), m.reference, m.topic).await } - PhxReply::Ok(reply) => match reply { - OkReply::NoMessage(Empty {}) => { - tracing::trace!(target: "phoenix_status", "Phoenix status message") - } - OkReply::Message(payload) => { - handler(Ok(payload.into()), m.reference, m.topic).await - } - }, }, - ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"), }, - Payload::ControlMessage(ControlMessage::PhxClose(_)) => { - return Err(Error::ClosedByPortal) - } - Payload::ControlMessage(ControlMessage::Disconnect { reason: _reason }) => { + Payload::Error(_) => {} + Payload::Close(Empty {}) => return Err(Error::ClosedByPortal), + Payload::Disconnect { reason: _reason } => { // TODO: pass the _reason up to the client so it can print a pertinent user message handler( Err(ChannelError::ErrorMsg(Error::ClosedByPortal)), @@ -309,28 +305,10 @@ pub type MessageResult = std::result::Result; #[derive(Debug)] pub enum ChannelError { - ErrorReply(ErrorInfo), + ErrorReply(ErrorReply), ErrorMsg(Error), } -#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] -#[serde(untagged)] -enum Payload { - // We might want other type for the reply message - // but that makes everything even more convoluted! - // and we need to think how to make this whole mess less convoluted. - Reply(ReplyMessage), - ControlMessage(ControlMessage), - Message(T), -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case", tag = "event", content = "payload")] -enum ControlMessage { - PhxClose(Empty), - Disconnect { reason: String }, -} - #[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] pub struct PhoenixMessage { // TODO: we should use a newtype pattern for topics @@ -341,6 +319,54 @@ pub struct PhoenixMessage { reference: Option, } +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +#[serde(tag = "event", content = "payload")] +enum Payload { + #[serde(rename = "phx_reply")] + Reply(Reply), + #[serde(rename = "phx_error")] + Error(Empty), + #[serde(rename = "phx_close")] + Close(Empty), + #[serde(rename = "disconnect")] + Disconnect { reason: String }, + #[serde(untagged)] + Message(T), +} + +// Awful hack to get serde_json to generate an empty "{}" instead of using "null" +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] +#[serde(deny_unknown_fields)] +struct Empty {} + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +#[serde(rename_all = "snake_case", tag = "status", content = "response")] +enum Reply { + Ok(OkReply), + Error { reason: ErrorReply }, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(untagged)] +enum OkReply { + Message(T), + NoMessage(Empty), +} + +/// This represents the info we have about the error +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ErrorReply { + #[serde(rename = "unmatched topic")] + UnmatchedTopic, + TokenExpired, + NotFound, + Offline, + Disabled, + #[serde(other)] + Other, +} + impl PhoenixMessage { pub fn new(topic: impl Into, payload: T, reference: Option) -> Self { Self { @@ -358,32 +384,25 @@ impl PhoenixMessage { Self { topic: topic.into(), // There has to be a better way :\ - payload: Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::Message( - payload, - )))), + payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))), reference: reference.into(), } } pub fn new_err_reply( topic: impl Into, - error: ErrorInfo, + reason: ErrorReply, reference: impl Into>, ) -> Self { Self { topic: topic.into(), // There has to be a better way :\ - payload: Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(error))), + payload: Payload::Reply(Reply::Error { reason }), reference: reference.into(), } } } -// Awful hack to get serde_json to generate an empty "{}" instead of using "null" -#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] -#[serde(deny_unknown_fields)] -struct Empty {} - #[derive(Debug, Deserialize, Serialize, Clone)] #[serde(rename_all = "snake_case", tag = "event", content = "payload")] enum EgressControlMessage { @@ -391,54 +410,6 @@ enum EgressControlMessage { Heartbeat(Empty), } -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case", tag = "event", content = "payload")] -enum ReplyMessage { - PhxReply(PhxReply), - PhxError(Empty), -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(untagged)] -enum OkReply { - Message(T), - NoMessage(Empty), -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -pub struct UnknownError(pub String); - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -pub enum KnownError { - #[serde(rename = "unmatched topic")] - UnmatchedTopic, - #[serde(rename = "token_expired")] - TokenExpired, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(untagged)] -pub enum Reason { - Known(KnownError), - Unknown(UnknownError), -} - -/// This represents the info we have about the error -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum ErrorInfo { - Reason(Reason), - Offline, - Disabled, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case", tag = "status", content = "response")] -enum PhxReply { - Ok(OkReply), - Error(ErrorInfo), -} - /// You can use this sender to send messages through a `PhoenixChannel`. /// /// Messages won't be sent unless [PhoenixChannel::start] is running, internally @@ -539,30 +510,27 @@ impl PhoenixSender { #[cfg(test)] mod tests { - use crate::control::{ - ControlMessage, Empty, ErrorInfo, KnownError, Payload, PhxReply::Error, Reason, - ReplyMessage::PhxReply, UnknownError, - }; + use super::*; #[test] fn unmatched_topic_reply() { let actual_reply = r#" { - "event":"phx_reply", - "ref":"12", - "topic":"client", + "event": "phx_reply", + "ref": "12", + "topic": "client", "payload":{ - "status":"error", + "status": "error", "response":{ - "reason":"unmatched topic" + "reason": "unmatched topic" } } } "#; let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(PhxReply(Error(ErrorInfo::Reason( - Reason::Known(KnownError::UnmatchedTopic), - )))); + let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + reason: ErrorReply::UnmatchedTopic, + }); assert_eq!(actual_reply, expected_reply); } @@ -577,7 +545,7 @@ mod tests { } "#; let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::ControlMessage(ControlMessage::PhxClose(Empty {})); + let expected_reply = Payload::<(), ()>::Close(Empty {}); assert_eq!(actual_reply, expected_reply); } @@ -592,8 +560,30 @@ mod tests { } "#; let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::ControlMessage(ControlMessage::Disconnect { + let expected_reply = Payload::<(), ()>::Disconnect { reason: "token_expired".to_string(), + }; + assert_eq!(actual_reply, expected_reply); + } + + #[test] + fn not_found() { + let actual_reply = r#" + { + "event": "phx_reply", + "ref": null, + "topic": "client", + "payload": { + "status": "error", + "response": { + "reason": "not_found" + } + } + } + "#; + let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); + let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + reason: ErrorReply::NotFound, }); assert_eq!(actual_reply, expected_reply); } @@ -602,21 +592,21 @@ mod tests { fn unexpected_error_reply() { let actual_reply = r#" { - "event":"phx_reply", - "ref":"12", - "topic":"client", - "payload":{ - "status":"error", - "response":{ - "reason":"bad reply" + "event": "phx_reply", + "ref": "12", + "topic": "client", + "payload": { + "status": "error", + "response": { + "reason": "bad reply" } } } "#; let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(PhxReply(Error(ErrorInfo::Reason( - Reason::Unknown(UnknownError("bad reply".to_string())), - )))); + let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + reason: ErrorReply::Other, + }); assert_eq!(actual_reply, expected_reply); } } diff --git a/rust/gateway/src/messages.rs b/rust/gateway/src/messages.rs index 5e93d51f4..6981694dd 100644 --- a/rust/gateway/src/messages.rs +++ b/rust/gateway/src/messages.rs @@ -13,8 +13,7 @@ use std::net::IpAddr; #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] pub struct InitGateway { pub interface: Interface, - pub ipv4_masquerade_enabled: bool, - pub ipv6_masquerade_enabled: bool, + pub config: Config, } #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] @@ -22,6 +21,12 @@ pub struct Actor { pub id: ActorId, } +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +pub struct Config { + pub ipv4_masquerade_enabled: bool, + pub ipv6_masquerade_enabled: bool, +} + #[derive(Debug, Deserialize, Serialize, Clone)] pub struct Client { pub id: ClientId, @@ -141,11 +146,10 @@ pub struct ConnectionReady { #[cfg(test)] mod test { - use connlib_shared::{control::PhoenixMessage, messages::Interface}; + use super::*; + use connlib_shared::control::PhoenixMessage; use phoenix_channel::InitMessage; - use super::{IngressMessages, InitGateway}; - #[test] fn request_connection_message() { let message = r#"{ @@ -206,11 +210,13 @@ mod test { ipv6: "fd00:2021:1111::2c:f6ab".parse().unwrap(), upstream_dns: vec![], }, - ipv4_masquerade_enabled: true, - ipv6_masquerade_enabled: true, + config: Config { + ipv4_masquerade_enabled: true, + ipv6_masquerade_enabled: true, + }, }); - let message = r#"{"event":"init","ref":null,"topic":"gateway","payload":{"interface":{"ipv6":"fd00:2021:1111::2c:f6ab","ipv4":"100.115.164.78"},"ipv4_masquerade_enabled":true,"ipv6_masquerade_enabled":true}}"#; + let message = r#"{"event":"init","ref":null,"topic":"gateway","payload":{"interface":{"ipv6":"fd00:2021:1111::2c:f6ab","ipv4":"100.115.164.78"},"config":{"ipv4_masquerade_enabled":true,"ipv6_masquerade_enabled":true}}}"#; let ingress_message = serde_json::from_str::>(message).unwrap(); assert_eq!(m, ingress_message); } diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 37d0d0b20..926da3ee0 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -344,9 +344,7 @@ where })) } }, - Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error( - ErrorInfo::Reason(reason), - ))) => { + Payload::Reply(Reply::Error { reason }) => { return Poll::Ready(Ok(Event::ErrorResponse { topic: message.topic, req_id: OutboundRequestId( @@ -355,9 +353,7 @@ where reason, })); } - Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::Message( - reply, - )))) => { + Payload::Reply(Reply::Ok(OkReply::Message(reply))) => { let req_id = OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?); @@ -376,18 +372,7 @@ where res: reply, })); } - Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error( - ErrorInfo::Offline, - ))) => { - tracing::warn!( - "Received offline error for request {:?}", - message.reference - ); - continue; - } - Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok( - OkReply::NoMessage(Empty {}), - ))) => { + Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))) => { let id = OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?); @@ -402,20 +387,20 @@ where continue; } - Payload::Reply(ReplyMessage::PhxError(Empty {})) => { + Payload::Error(Empty {}) => { return Poll::Ready(Ok(Event::ErrorResponse { topic: message.topic, req_id: OutboundRequestId( message.reference.ok_or(Error::MissingReplyId)?, ), - reason: "unknown error (bad event?)".to_owned(), + reason: ErrorReply::Other, })) } - Payload::ControlMessage(ControlMessage::PhxClose(_)) => { + Payload::Close(Empty {}) => { self.reconnect_on_transient_error(Error::CloseMessage); continue; } - Payload::ControlMessage(ControlMessage::Disconnect { reason }) => { + Payload::Disconnect { reason } => { return Poll::Ready(Ok(Event::Disconnect(reason))); } } @@ -523,7 +508,7 @@ pub enum Event { ErrorResponse { topic: String, req_id: OutboundRequestId, - reason: String, + reason: ErrorReply, }, /// The server sent us a message, most likely this is a broadcast to all connected clients. InboundMessage { @@ -538,26 +523,9 @@ pub enum Event { Disconnect(String), } -#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] -#[serde(untagged)] -enum Payload { - // We might want other type for the reply message - // but that makes everything even more convoluted! - // and we need to think how to make this whole mess less convoluted. - Reply(ReplyMessage), - ControlMessage(ControlMessage), - Message(T), -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case", tag = "event", content = "payload")] -enum ControlMessage { - PhxClose(Empty), - Disconnect { reason: String }, -} - #[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] pub struct PhoenixMessage { + // TODO: we should use a newtype pattern for topics topic: String, #[serde(flatten)] payload: Payload, @@ -565,6 +533,54 @@ pub struct PhoenixMessage { reference: Option, } +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +#[serde(tag = "event", content = "payload")] +enum Payload { + #[serde(rename = "phx_reply")] + Reply(Reply), + #[serde(rename = "phx_error")] + Error(Empty), + #[serde(rename = "phx_close")] + Close(Empty), + #[serde(rename = "disconnect")] + Disconnect { reason: String }, + #[serde(untagged)] + Message(T), +} + +// Awful hack to get serde_json to generate an empty "{}" instead of using "null" +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] +#[serde(deny_unknown_fields)] +struct Empty {} + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +#[serde(rename_all = "snake_case", tag = "status", content = "response")] +enum Reply { + Ok(OkReply), + Error { reason: ErrorReply }, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(untagged)] +enum OkReply { + Message(T), + NoMessage(Empty), +} + +/// This represents the info we have about the error +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ErrorReply { + #[serde(rename = "unmatched topic")] + UnmatchedTopic, + TokenExpired, + NotFound, + Offline, + Disabled, + #[serde(other)] + Other, +} + impl PhoenixMessage { pub fn new(topic: impl Into, payload: T, reference: u64) -> Self { Self { @@ -609,11 +625,6 @@ fn make_request(secret_url: Secret, user_agent: String) -> Result { @@ -621,34 +632,6 @@ enum EgressControlMessage { Heartbeat(Empty), } -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case", tag = "event", content = "payload")] -enum ReplyMessage { - PhxReply(PhxReply), - PhxError(Empty), -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(untagged)] -enum OkReply { - Message(T), - NoMessage(Empty), -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -enum ErrorInfo { - Reason(String), - Offline, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case", tag = "status", content = "response")] -enum PhxReply { - Ok(OkReply), - Error(ErrorInfo), -} - #[cfg(test)] mod tests { use super::*; @@ -662,14 +645,14 @@ mod tests { #[test] fn can_deserialize_inbound_message() { let msg = r#"{ - "topic": "room:lobby", - "ref": null, - "payload": { - "hello": "world" - }, - "join_ref": null, - "event": "shout" -}"#; + "topic": "room:lobby", + "ref": null, + "payload": { + "hello": "world" + }, + "join_ref": null, + "event": "shout" + }"#; let msg = serde_json::from_str::>(msg).unwrap(); @@ -698,4 +681,102 @@ mod tests { Payload::Message(InitMessage::Init(EmptyInit {})) ); } + + #[test] + fn unmatched_topic_reply() { + let actual_reply = r#" + { + "event": "phx_reply", + "ref": "12", + "topic": "client", + "payload":{ + "status": "error", + "response":{ + "reason": "unmatched topic" + } + } + } + "#; + let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); + let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + reason: ErrorReply::UnmatchedTopic, + }); + assert_eq!(actual_reply, expected_reply); + } + + #[test] + fn phx_close() { + let actual_reply = r#" + { + "event": "phx_close", + "ref": null, + "topic": "client", + "payload": {} + } + "#; + let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); + let expected_reply = Payload::<(), ()>::Close(Empty {}); + assert_eq!(actual_reply, expected_reply); + } + + #[test] + fn token_expired() { + let actual_reply = r#" + { + "event": "disconnect", + "ref": null, + "topic": "client", + "payload": { "reason": "token_expired" } + } + "#; + let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); + let expected_reply = Payload::<(), ()>::Disconnect { + reason: "token_expired".to_string(), + }; + assert_eq!(actual_reply, expected_reply); + } + + #[test] + fn not_found() { + let actual_reply = r#" + { + "event": "phx_reply", + "ref": null, + "topic": "client", + "payload": { + "status": "error", + "response": { + "reason": "not_found" + } + } + } + "#; + let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); + let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + reason: ErrorReply::NotFound, + }); + assert_eq!(actual_reply, expected_reply); + } + + #[test] + fn unexpected_error_reply() { + let actual_reply = r#" + { + "event": "phx_reply", + "ref": "12", + "topic": "client", + "payload": { + "status": "error", + "response": { + "reason": "bad reply" + } + } + } + "#; + let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); + let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + reason: ErrorReply::Other, + }); + assert_eq!(actual_reply, expected_reply); + } } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 2ac44372b..6a7f33af9 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -516,7 +516,7 @@ where req_id, reason, }))) => { - tracing::warn!(target: "relay", "Request with ID {req_id} on topic {topic} failed: {reason}"); + tracing::warn!(target: "relay", "Request with ID {req_id} on topic {topic} failed: {reason:?}"); continue; } Some(Poll::Ready(Ok(Event::HeartbeatSent))) => {