diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs index 011eedc39..7c0bf3619 100644 --- a/rust/client-shared/src/eventloop.rs +++ b/rust/client-shared/src/eventloop.rs @@ -96,7 +96,7 @@ impl Eventloop { pub(crate) fn new( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, cmd_rx: mpsc::UnboundedReceiver, resource_list_sender: watch::Sender>, tun_config_sender: watch::Sender>, @@ -432,7 +432,7 @@ impl Eventloop { } async fn phoenix_channel_event_loop( - mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, ) { @@ -449,7 +449,7 @@ async fn phoenix_channel_event_loop( break; } } - Either::Left((Ok(phoenix_channel::Event::SuccessResponse { res: (), .. }), _)) => {} + Either::Left((Ok(phoenix_channel::Event::SuccessResponse { .. }), _)) => {} Either::Left((Ok(phoenix_channel::Event::ErrorResponse { res, req_id, topic }), _)) => { match res { ErrorReply::Disabled => { diff --git a/rust/client-shared/src/lib.rs b/rust/client-shared/src/lib.rs index f1a53b8fb..cdc0f1ee5 100644 --- a/rust/client-shared/src/lib.rs +++ b/rust/client-shared/src/lib.rs @@ -57,7 +57,7 @@ impl Session { pub fn connect( tcp_socket_factory: Arc>, udp_socket_factory: Arc>, - portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, handle: tokio::runtime::Handle, ) -> (Self, EventStream) { let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); diff --git a/rust/connlib/phoenix-channel/src/lib.rs b/rust/connlib/phoenix-channel/src/lib.rs index 0846acc27..e0607356a 100644 --- a/rust/connlib/phoenix-channel/src/lib.rs +++ b/rust/connlib/phoenix-channel/src/lib.rs @@ -36,7 +36,7 @@ pub use tokio_tungstenite::tungstenite::http::StatusCode; const MAX_BUFFERED_MESSAGES: usize = 32; // Chosen pretty arbitrarily. If we are connected, these should never build up. -pub struct PhoenixChannel { +pub struct PhoenixChannel { state: State, waker: Option, pending_joins: VecDeque, @@ -46,7 +46,7 @@ pub struct PhoenixChannel { heartbeat: tokio::time::Interval, - _phantom: PhantomData<(TInboundMsg, TOutboundRes)>, + _phantom: PhantomData, pending_join_requests: HashMap, @@ -246,12 +246,10 @@ impl fmt::Display for OutboundRequestId { #[error("Cannot close websocket while we are connecting")] pub struct Connecting; -impl - PhoenixChannel +impl PhoenixChannel where TInitReq: Serialize + Clone, TInboundMsg: DeserializeOwned, - TOutboundRes: DeserializeOwned, TFinish: IntoIterator, { /// Creates a new [PhoenixChannel] to the given endpoint in the `disconnected` state. @@ -374,10 +372,7 @@ where Ok(()) } - pub fn poll( - &mut self, - cx: &mut Context, - ) -> Poll, Error>> { + pub fn poll(&mut self, cx: &mut Context) -> Poll, Error>> { loop { // First, check if we are connected. let stream = match &mut self.state { @@ -547,20 +542,21 @@ where tracing::trace!(target: "wire::api::recv", %message); - let message = match serde_json::from_str::< - PhoenixMessage, - >(&message) - { - Ok(m) => m, - Err(e) if e.is_io() || e.is_eof() => { - self.reconnect_on_transient_error(InternalError::Serde(e)); - continue; - } - Err(e) => { - tracing::warn!("Failed to deserialize message: {}", err_with_src(&e)); - continue; - } - }; + let message = + match serde_json::from_str::>(&message) { + Ok(m) => m, + Err(e) if e.is_io() || e.is_eof() => { + self.reconnect_on_transient_error(InternalError::Serde(e)); + continue; + } + Err(e) => { + tracing::warn!( + "Failed to deserialize message: {}", + err_with_src(&e) + ); + continue; + } + }; match (message.payload, message.reference) { (Payload::Message(msg), _) => { @@ -586,11 +582,10 @@ where res: reason, })); } - (Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => { + (Payload::Reply(Reply::Ok(OkReply::Message(()))), Some(req_id)) => { return Poll::Ready(Ok(Event::SuccessResponse { topic: message.topic, req_id, - res: reply, })); } (Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => { @@ -710,12 +705,10 @@ where } #[derive(Debug)] -pub enum Event { +pub enum Event { SuccessResponse { topic: String, req_id: OutboundRequestId, - /// The response received for an outbound request. - res: TOutboundRes, }, ErrorResponse { topic: String, @@ -741,20 +734,20 @@ pub enum Event { } #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] -pub struct PhoenixMessage { +pub struct PhoenixMessage { // TODO: we should use a newtype pattern for topics topic: String, #[serde(flatten)] - payload: Payload, + payload: Payload, #[serde(rename = "ref")] reference: Option, } #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] #[serde(tag = "event", content = "payload")] -enum Payload { +enum Payload { #[serde(rename = "phx_reply")] - Reply(Reply), + Reply(Reply), #[serde(rename = "phx_error")] Error(Empty), #[serde(rename = "phx_close")] @@ -772,8 +765,8 @@ struct Empty {} #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] #[serde(rename_all = "snake_case", tag = "status", content = "response")] -enum Reply { - Ok(OkReply), +enum Reply { + Ok(OkReply<()>), // We never expect responses for our requests. Error { reason: ErrorReply }, } @@ -814,7 +807,7 @@ pub enum DisconnectReason { TokenExpired, } -impl PhoenixMessage { +impl PhoenixMessage { pub fn new_message( topic: impl Into, payload: T, @@ -827,14 +820,10 @@ impl PhoenixMessage { } } - pub fn new_ok_reply( - topic: impl Into, - payload: R, - reference: Option, - ) -> Self { + pub fn new_ok_reply(topic: impl Into, reference: Option) -> Self { Self { topic: topic.into(), - payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))), + payload: Payload::Reply(Reply::Ok(OkReply::Message(()))), reference, } } @@ -886,7 +875,7 @@ fn serialize_msg( payload: impl Serialize, request_id: OutboundRequestId, ) -> String { - serde_json::to_string(&PhoenixMessage::<_, ()>::new_message( + serde_json::to_string(&PhoenixMessage::new_message( topic, payload, Some(request_id), @@ -916,7 +905,7 @@ mod tests { "event": "shout" }"#; - let msg = serde_json::from_str::>(msg).unwrap(); + let msg = serde_json::from_str::>(msg).unwrap(); assert_eq!(msg.topic, "room:lobby"); assert_eq!(msg.reference, None); @@ -943,8 +932,8 @@ mod tests { } } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Reply(Reply::Error { reason: ErrorReply::UnmatchedTopic, }); assert_eq!(actual_reply, expected_reply); @@ -960,8 +949,8 @@ mod tests { "payload": {} } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Close(Empty {}); + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Close(Empty {}); assert_eq!(actual_reply, expected_reply); } @@ -975,8 +964,8 @@ mod tests { "payload": { "reason": "token_expired" } } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Disconnect { + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Disconnect { reason: DisconnectReason::TokenExpired, }; assert_eq!(actual_reply, expected_reply); @@ -997,8 +986,8 @@ mod tests { } } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Reply(Reply::Error { reason: ErrorReply::Other, }); assert_eq!(actual_reply, expected_reply); @@ -1019,8 +1008,8 @@ mod tests { } } "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { + let actual_reply = serde_json::from_str::>(actual_reply).unwrap(); + let expected_reply = Payload::<()>::Reply(Reply::Error { reason: ErrorReply::InvalidVersion, }); assert_eq!(actual_reply, expected_reply); @@ -1030,7 +1019,7 @@ mod tests { fn disabled_err_reply() { let json = r#"{"event":"phx_reply","ref":null,"topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#; - let actual = serde_json::from_str::>(json).unwrap(); + let actual = serde_json::from_str::>(json).unwrap(); let expected = PhoenixMessage::new_err_reply("client", ErrorReply::Disabled, None); assert_eq!(actual, expected) diff --git a/rust/connlib/phoenix-channel/tests/lib.rs b/rust/connlib/phoenix-channel/tests/lib.rs index cd87be312..5a9c0c350 100644 --- a/rust/connlib/phoenix-channel/tests/lib.rs +++ b/rust/connlib/phoenix-channel/tests/lib.rs @@ -70,7 +70,7 @@ async fn client_does_not_pipeline_messages() { .unwrap(), ); - let mut channel = PhoenixChannel::<(), InboundMsg, (), _>::disconnected( + let mut channel = PhoenixChannel::<(), InboundMsg, _>::disconnected( login_url, "test/1.0.0".to_owned(), "test", diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 0a95bc5d0..856a0bdd9 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -79,7 +79,7 @@ enum PortalCommand { impl Eventloop { pub(crate) fn new( tunnel: GatewayTunnel, - mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, tun_device_manager: TunDeviceManager, ) -> Result { portal.connect(PublicKeyParam(tunnel.public_key().to_bytes())); @@ -637,7 +637,7 @@ impl Eventloop { } async fn phoenix_channel_event_loop( - mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, + mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>, event_tx: mpsc::Sender>, mut cmd_rx: mpsc::Receiver, ) { @@ -661,7 +661,7 @@ async fn phoenix_channel_event_loop( } Either::Left(( Ok( - phoenix_channel::Event::SuccessResponse { res: (), .. } + phoenix_channel::Event::SuccessResponse { .. } | phoenix_channel::Event::HeartbeatSent | phoenix_channel::Event::JoinedRoom { .. }, ), diff --git a/rust/relay/server/src/main.rs b/rust/relay/server/src/main.rs index e6d5964bd..7fbe952ad 100644 --- a/rust/relay/server/src/main.rs +++ b/rust/relay/server/src/main.rs @@ -416,7 +416,7 @@ struct Eventloop { sockets: Sockets, server: Server, - channel: Option>, + channel: Option>, sleep: Sleep, ebpf: Option, @@ -438,7 +438,7 @@ where fn new( server: Server, ebpf: Option, - channel: PhoenixChannel, + channel: PhoenixChannel, public_address: IpStack, last_heartbeat_sent: Arc>>, ) -> Result { @@ -719,9 +719,9 @@ where Ok(()) } - fn handle_portal_event(&mut self, event: phoenix_channel::Event) { + fn handle_portal_event(&mut self, event: phoenix_channel::Event) { match event { - Event::SuccessResponse { res: (), .. } => {} + Event::SuccessResponse { .. } => {} Event::JoinedRoom { topic } => { tracing::info!(target: "relay", "Successfully joined room '{topic}'"); }