From 69afe7121578efdf7318434b6d559117426b5891 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 17 Sep 2025 04:10:56 +0000 Subject: [PATCH] refactor(connlib): remove concept of "ReplyMessages" (#10361) In earlier versions of Firezone, the WebSocket protocol with the portal was using the request-response semantics built into Phoenix. This however is quite cumbersome to work with to due to the polymorphic nature of the protocol design. We ended up moving away from it and instead only use one-way messages where each event directly corresponds to a message type. However, we have never removed the capability reply messages from the `phoenix-channel` module, instead all usages just set it to `()`. We can simplify the code here by always setting this to `()`. Resolves: #7091 --- rust/client-shared/src/eventloop.rs | 6 +- rust/client-shared/src/lib.rs | 2 +- rust/connlib/phoenix-channel/src/lib.rs | 97 ++++++++++------------- rust/connlib/phoenix-channel/tests/lib.rs | 2 +- rust/gateway/src/eventloop.rs | 6 +- rust/relay/server/src/main.rs | 8 +- 6 files changed, 55 insertions(+), 66 deletions(-) diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs index 011eedc39..7c0bf3619 100644 --- a/rust/client-shared/src/eventloop.rs +++ b/rust/client-shared/src/eventloop.rs @@ -96,7 +96,7 @@ impl Eventloop { pub(crate) fn new( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, cmd_rx: mpsc::UnboundedReceiver, resource_list_sender: watch::Sender>, tun_config_sender: watch::Sender>, @@ -432,7 +432,7 @@ impl Eventloop { } async fn phoenix_channel_event_loop( - mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, ) { @@ -449,7 +449,7 @@ async fn phoenix_channel_event_loop( break; } } - Either::Left((Ok(phoenix_channel::Event::SuccessResponse { res: (), .. }), _)) => {} + Either::Left((Ok(phoenix_channel::Event::SuccessResponse { .. }), _)) => {} Either::Left((Ok(phoenix_channel::Event::ErrorResponse { res, req_id, topic }), _)) => { match res { ErrorReply::Disabled => { diff --git a/rust/client-shared/src/lib.rs b/rust/client-shared/src/lib.rs index f1a53b8fb..cdc0f1ee5 100644 --- a/rust/client-shared/src/lib.rs +++ b/rust/client-shared/src/lib.rs @@ -57,7 +57,7 @@ impl Session { pub fn connect( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, handle: tokio::runtime::Handle, ) -> (Self, EventStream) { let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); diff --git a/rust/connlib/phoenix-channel/src/lib.rs b/rust/connlib/phoenix-channel/src/lib.rs index 0846acc27..e0607356a 100644 --- a/rust/connlib/phoenix-channel/src/lib.rs +++ b/rust/connlib/phoenix-channel/src/lib.rs @@ -36,7 +36,7 @@ pub use tokio_tungstenite::tungstenite::http::StatusCode; const MAX_BUFFERED_MESSAGES: usize = 32; // Chosen pretty arbitrarily. If we are connected, these should never build up. -pub struct PhoenixChannel { +pub struct PhoenixChannel { state: State, waker: Option, pending_joins: VecDeque, @@ -46,7 +46,7 @@ pub struct PhoenixChannel { heartbeat: tokio::time::Interval, - _phantom: PhantomData<(TInboundMsg, TOutboundRes)>, + _phantom: PhantomData, pending_join_requests: HashMap, @@ -246,12 +246,10 @@ impl fmt::Display for OutboundRequestId { #[error("Cannot close websocket while we are connecting")] pub struct Connecting; -impl - PhoenixChannel +impl PhoenixChannel where TInitReq: Serialize + Clone, TInboundMsg: DeserializeOwned, - TOutboundRes: DeserializeOwned, TFinish: IntoIterator, { /// Creates a new [PhoenixChannel] to the given endpoint in the `disconnected` state. @@ -374,10 +372,7 @@ where Ok(()) } - pub fn poll( - &mut self, - cx: &mut Context, - ) -> Poll, Error>> { + pub fn poll(&mut self, cx: &mut Context) -> Poll, Error>> { loop { // First, check if we are connected. let stream = match &mut self.state { @@ -547,20 +542,21 @@ where tracing::trace!(target: "wire::api::recv", %message); - let message = match serde_json::from_str::< - PhoenixMessage, - >(&message) - { - Ok(m) => m, - Err(e) if e.is_io() || e.is_eof() => { - self.reconnect_on_transient_error(InternalError::Serde(e)); - continue; - } - Err(e) => { - tracing::warn!("Failed to deserialize message: {}", err_with_src(&e)); - continue; - } - }; + let message = + match serde_json::from_str::>(&message) { + Ok(m) => m, + Err(e) if e.is_io() || e.is_eof() => { + self.reconnect_on_transient_error(InternalError::Serde(e)); + continue; + } + Err(e) => { + tracing::warn!( + "Failed to deserialize message: {}", + err_with_src(&e) + ); + continue; + } + }; match (message.payload, message.reference) { (Payload::Message(msg), _) => { @@ -586,11 +582,10 @@ where res: reason, })); } - (Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => { + (Payload::Reply(Reply::Ok(OkReply::Message(()))), Some(req_id)) => { return Poll::Ready(Ok(Event::SuccessResponse { topic: message.topic, req_id, - res: reply, })); } (Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => { @@ -710,12 +705,10 @@ where } #[derive(Debug)] -pub enum Event { +pub enum Event { SuccessResponse { topic: String, req_id: OutboundRequestId, - /// The response received for an outbound request. - res: TOutboundRes, }, ErrorResponse { topic: String, @@ -741,20 +734,20 @@ pub enum Event { } #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] -pub struct PhoenixMessage { +pub struct PhoenixMessage { // TODO: we should use a newtype pattern for topics topic: String, #[serde(flatten)] - payload: Payload, + payload: Payload, #[serde(rename = "ref")] reference: Option, } #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] #[serde(tag = "event", content = "payload")] -enum Payload { +enum Payload { #[serde(rename = "phx_reply")] - Reply(Reply), + Reply(Reply), #[serde(rename = "phx_error")] Error(Empty), #[serde(rename = "phx_close")] @@ -772,8 +765,8 @@ struct Empty {} #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] #[serde(rename_all = "snake_case", tag = "status", content = "response")] -enum Reply { - Ok(OkReply), +enum Reply { + Ok(OkReply<()>), // We never expect responses for our requests. Error { reason: ErrorReply }, } @@ -814,7 +807,7 @@ pub enum DisconnectReason { TokenExpired, } -impl PhoenixMessage { +impl PhoenixMessage { pub fn new_message( topic: impl Into, payload: T, @@ -827,14 +820,10 @@ impl PhoenixMessage { } } - pub fn new_ok_reply( - topic: impl Into, - payload: R, - reference: Option, - ) -> Self { + pub fn new_ok_reply(topic: impl Into, reference: Option) -> Self { Self { topic: topic.into(), - payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))), + payload: Payload::Reply(Reply::Ok(OkReply::Message(()))), reference, } } @@ -886,7 +875,7 @@ fn serialize_msg( payload: impl Serialize, request_id: OutboundRequestId, ) -> String { - serde_json::to_string(&PhoenixMessage::<_, ()>::new_message( + serde_json::to_string(&PhoenixMessage::new_message( topic, payload, Some(request_id), @@ -916,7 +905,7 @@ mod tests { "event": "shout" }"#; - let msg = serde_json::from_str::>(msg).unwrap(); + let msg = serde_json::from_str::>(msg).unwrap(); assert_eq!(msg.topic, "room:lobby"); assert_eq!(msg.reference, None); @@ -943,8 +932,8 @@ mod tests { } } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Reply(Reply::Error { reason: ErrorReply::UnmatchedTopic, }); assert_eq!(actual_reply, expected_reply); @@ -960,8 +949,8 @@ mod tests { "payload": {} } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Close(Empty {}); + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Close(Empty {}); assert_eq!(actual_reply, expected_reply); } @@ -975,8 +964,8 @@ mod tests { "payload": { "reason": "token_expired" } } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Disconnect { + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Disconnect { reason: DisconnectReason::TokenExpired, }; assert_eq!(actual_reply, expected_reply); @@ -997,8 +986,8 @@ mod tests { } } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Reply(Reply::Error { reason: ErrorReply::Other, }); assert_eq!(actual_reply, expected_reply); @@ -1019,8 +1008,8 @@ mod tests { } } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Reply(Reply::Error { reason: ErrorReply::InvalidVersion, }); assert_eq!(actual_reply, expected_reply); @@ -1030,7 +1019,7 @@ mod tests { fn disabled_err_reply() { let json = r#"{"event":"phx_reply","ref":null,"topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#; - let actual = serde_json::from_str::>(json).unwrap(); + let actual = serde_json::from_str::>(json).unwrap(); let expected = PhoenixMessage::new_err_reply("client", ErrorReply::Disabled, None); assert_eq!(actual, expected) diff --git a/rust/connlib/phoenix-channel/tests/lib.rs b/rust/connlib/phoenix-channel/tests/lib.rs index cd87be312..5a9c0c350 100644 --- a/rust/connlib/phoenix-channel/tests/lib.rs +++ b/rust/connlib/phoenix-channel/tests/lib.rs @@ -70,7 +70,7 @@ async fn client_does_not_pipeline_messages() { .unwrap(), ); - let mut channel = PhoenixChannel::<(), InboundMsg, (), _>::disconnected( + let mut channel = PhoenixChannel::<(), InboundMsg, _>::disconnected( login_url, "test/1.0.0".to_owned(), "test", diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 0a95bc5d0..856a0bdd9 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -79,7 +79,7 @@ enum PortalCommand { impl Eventloop { pub(crate) fn new( tunnel: GatewayTunnel, - mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, tun_device_manager: TunDeviceManager, ) -> Result { portal.connect(PublicKeyParam(tunnel.public_key().to_bytes())); @@ -637,7 +637,7 @@ impl Eventloop { } async fn phoenix_channel_event_loop( - mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, ) { @@ -661,7 +661,7 @@ async fn phoenix_channel_event_loop( } Either::Left(( Ok( - phoenix_channel::Event::SuccessResponse { res: (), .. } + phoenix_channel::Event::SuccessResponse { .. } | phoenix_channel::Event::HeartbeatSent | phoenix_channel::Event::JoinedRoom { .. }, ), diff --git a/rust/relay/server/src/main.rs b/rust/relay/server/src/main.rs index e6d5964bd..7fbe952ad 100644 --- a/rust/relay/server/src/main.rs +++ b/rust/relay/server/src/main.rs @@ -416,7 +416,7 @@ struct Eventloop { sockets: Sockets, server: Server, - channel: Option>, + channel: Option>, sleep: Sleep, ebpf: Option, @@ -438,7 +438,7 @@ where fn new( server: Server, ebpf: Option, - channel: PhoenixChannel, + channel: PhoenixChannel, public_address: IpStack, last_heartbeat_sent: Arc>>, ) -> Result { @@ -719,9 +719,9 @@ where Ok(()) } - fn handle_portal_event(&mut self, event: phoenix_channel::Event) { + fn handle_portal_event(&mut self, event: phoenix_channel::Event) { match event { - Event::SuccessResponse { res: (), .. } => {} + Event::SuccessResponse { .. } => {} Event::JoinedRoom { topic } => { tracing::info!(target: "relay", "Successfully joined room '{topic}'"); }