fix(connlib): wait for room join before sending messages (#9656)

To avoid race conditions, we wait for all room joins on the WebSocket to
be successful before sending any messages to the portal. This requires
us to split room join messages from other messages so we can still send
them separately.

Resolves: #9647
This commit is contained in:
Thomas Eizinger
2025-06-25 19:34:53 +02:00
committed by GitHub
parent bebc69e2bc
commit f435510dab

View File

@@ -38,6 +38,7 @@ const MAX_BUFFERED_MESSAGES: usize = 32; // Chosen pretty arbitrarily. If we are
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish> {
state: State,
waker: Option<Waker>,
pending_joins: VecDeque<String>,
pending_messages: VecDeque<String>,
next_request_id: u64,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
@@ -277,6 +278,7 @@ where
state: State::Closed,
socket_factory,
waker: None,
pending_joins: VecDeque::with_capacity(MAX_BUFFERED_MESSAGES),
pending_messages: VecDeque::with_capacity(MAX_BUFFERED_MESSAGES),
_phantom: PhantomData,
heartbeat: tokio::time::interval(Duration::from_secs(30)),
@@ -294,8 +296,8 @@ where
/// If successful, a [`Event::JoinedRoom`] event will be emitted.
pub fn join(&mut self, topic: impl Into<String>, payload: impl Serialize) {
let (request_id, msg) = self.make_message(topic, EgressControlMessage::PhxJoin(payload));
self.pending_messages.push_front(msg); // Must send the join message before all others.
self.pending_joins.push_back(msg);
self.pending_join_requests.insert(request_id);
}
@@ -463,7 +465,17 @@ where
// Priority 1: Keep local buffers small and send pending messages.
match stream.poll_ready_unpin(cx) {
Poll::Ready(Ok(())) => {
if let Some(message) = self.pending_messages.pop_front() {
// Process join messages before other messages.
// Only process other messages if no room joins are pending.
let next_message = self.pending_joins.pop_front().or_else(|| {
if self.pending_join_requests.is_empty() {
return None;
}
self.pending_messages.pop_front()
});
if let Some(message) = next_message {
match stream.start_send_unpin(Message::Text(message.clone())) {
Ok(()) => {
tracing::trace!(target: "wire::api::send", %message);
@@ -488,6 +500,7 @@ where
self.reconnect_on_transient_error(InternalError::WebSocket(e));
}
}
continue;
}
}