refactor(connlib): remove concept of "ReplyMessages" (#10361)

In earlier versions of Firezone, the WebSocket protocol with the portal
was using the request-response semantics built into Phoenix. This
however is quite cumbersome to work with to due to the polymorphic
nature of the protocol design.

We ended up moving away from it and instead only use one-way messages
where each event directly corresponds to a message type. However, we
have never removed the capability reply messages from the
`phoenix-channel` module, instead all usages just set it to `()`.

We can simplify the code here by always setting this to `()`.

Resolves: #7091
This commit is contained in:
Thomas Eizinger
2025-09-17 04:10:56 +00:00
committed by GitHub
parent b1ed2f8a5e
commit 69afe71215
6 changed files with 55 additions and 66 deletions

View File

@@ -96,7 +96,7 @@ impl Eventloop {
pub(crate) fn new(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
cmd_rx: mpsc::UnboundedReceiver<Command>,
resource_list_sender: watch::Sender<Vec<ResourceView>>,
tun_config_sender: watch::Sender<Option<TunConfig>>,
@@ -432,7 +432,7 @@ impl Eventloop {
}
async fn phoenix_channel_event_loop(
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
event_tx: mpsc::Sender<Result<IngressMessages, phoenix_channel::Error>>,
mut cmd_rx: mpsc::Receiver<PortalCommand>,
) {
@@ -449,7 +449,7 @@ async fn phoenix_channel_event_loop(
break;
}
}
Either::Left((Ok(phoenix_channel::Event::SuccessResponse { res: (), .. }), _)) => {}
Either::Left((Ok(phoenix_channel::Event::SuccessResponse { .. }), _)) => {}
Either::Left((Ok(phoenix_channel::Event::ErrorResponse { res, req_id, topic }), _)) => {
match res {
ErrorReply::Disabled => {

View File

@@ -57,7 +57,7 @@ impl Session {
pub fn connect(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
handle: tokio::runtime::Handle,
) -> (Self, EventStream) {
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();

View File

@@ -36,7 +36,7 @@ pub use tokio_tungstenite::tungstenite::http::StatusCode;
const MAX_BUFFERED_MESSAGES: usize = 32; // Chosen pretty arbitrarily. If we are connected, these should never build up.
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish> {
pub struct PhoenixChannel<TInitReq, TInboundMsg, TFinish> {
state: State,
waker: Option<Waker>,
pending_joins: VecDeque<String>,
@@ -46,7 +46,7 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish> {
heartbeat: tokio::time::Interval,
_phantom: PhantomData<(TInboundMsg, TOutboundRes)>,
_phantom: PhantomData<TInboundMsg>,
pending_join_requests: HashMap<OutboundRequestId, Instant>,
@@ -246,12 +246,10 @@ impl fmt::Display for OutboundRequestId {
#[error("Cannot close websocket while we are connecting")]
pub struct Connecting;
impl<TInitReq, TInboundMsg, TOutboundRes, TFinish>
PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish>
impl<TInitReq, TInboundMsg, TFinish> PhoenixChannel<TInitReq, TInboundMsg, TFinish>
where
TInitReq: Serialize + Clone,
TInboundMsg: DeserializeOwned,
TOutboundRes: DeserializeOwned,
TFinish: IntoIterator<Item = (&'static str, String)>,
{
/// Creates a new [PhoenixChannel] to the given endpoint in the `disconnected` state.
@@ -374,10 +372,7 @@ where
Ok(())
}
pub fn poll(
&mut self,
cx: &mut Context,
) -> Poll<Result<Event<TInboundMsg, TOutboundRes>, Error>> {
pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<Event<TInboundMsg>, Error>> {
loop {
// First, check if we are connected.
let stream = match &mut self.state {
@@ -547,20 +542,21 @@ where
tracing::trace!(target: "wire::api::recv", %message);
let message = match serde_json::from_str::<
PhoenixMessage<TInboundMsg, TOutboundRes>,
>(&message)
{
Ok(m) => m,
Err(e) if e.is_io() || e.is_eof() => {
self.reconnect_on_transient_error(InternalError::Serde(e));
continue;
}
Err(e) => {
tracing::warn!("Failed to deserialize message: {}", err_with_src(&e));
continue;
}
};
let message =
match serde_json::from_str::<PhoenixMessage<TInboundMsg>>(&message) {
Ok(m) => m,
Err(e) if e.is_io() || e.is_eof() => {
self.reconnect_on_transient_error(InternalError::Serde(e));
continue;
}
Err(e) => {
tracing::warn!(
"Failed to deserialize message: {}",
err_with_src(&e)
);
continue;
}
};
match (message.payload, message.reference) {
(Payload::Message(msg), _) => {
@@ -586,11 +582,10 @@ where
res: reason,
}));
}
(Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => {
(Payload::Reply(Reply::Ok(OkReply::Message(()))), Some(req_id)) => {
return Poll::Ready(Ok(Event::SuccessResponse {
topic: message.topic,
req_id,
res: reply,
}));
}
(Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => {
@@ -710,12 +705,10 @@ where
}
#[derive(Debug)]
pub enum Event<TInboundMsg, TOutboundRes> {
pub enum Event<TInboundMsg> {
SuccessResponse {
topic: String,
req_id: OutboundRequestId,
/// The response received for an outbound request.
res: TOutboundRes,
},
ErrorResponse {
topic: String,
@@ -741,20 +734,20 @@ pub enum Event<TInboundMsg, TOutboundRes> {
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct PhoenixMessage<T, R> {
pub struct PhoenixMessage<T> {
// TODO: we should use a newtype pattern for topics
topic: String,
#[serde(flatten)]
payload: Payload<T, R>,
payload: Payload<T>,
#[serde(rename = "ref")]
reference: Option<OutboundRequestId>,
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(tag = "event", content = "payload")]
enum Payload<T, R> {
enum Payload<T> {
#[serde(rename = "phx_reply")]
Reply(Reply<R>),
Reply(Reply),
#[serde(rename = "phx_error")]
Error(Empty),
#[serde(rename = "phx_close")]
@@ -772,8 +765,8 @@ struct Empty {}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
enum Reply<T> {
Ok(OkReply<T>),
enum Reply {
Ok(OkReply<()>), // We never expect responses for our requests.
Error { reason: ErrorReply },
}
@@ -814,7 +807,7 @@ pub enum DisconnectReason {
TokenExpired,
}
impl<T, R> PhoenixMessage<T, R> {
impl<T> PhoenixMessage<T> {
pub fn new_message(
topic: impl Into<String>,
payload: T,
@@ -827,14 +820,10 @@ impl<T, R> PhoenixMessage<T, R> {
}
}
pub fn new_ok_reply(
topic: impl Into<String>,
payload: R,
reference: Option<OutboundRequestId>,
) -> Self {
pub fn new_ok_reply(topic: impl Into<String>, reference: Option<OutboundRequestId>) -> Self {
Self {
topic: topic.into(),
payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))),
payload: Payload::Reply(Reply::Ok(OkReply::Message(()))),
reference,
}
}
@@ -886,7 +875,7 @@ fn serialize_msg(
payload: impl Serialize,
request_id: OutboundRequestId,
) -> String {
serde_json::to_string(&PhoenixMessage::<_, ()>::new_message(
serde_json::to_string(&PhoenixMessage::new_message(
topic,
payload,
Some(request_id),
@@ -916,7 +905,7 @@ mod tests {
"event": "shout"
}"#;
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
let msg = serde_json::from_str::<PhoenixMessage<Msg>>(msg).unwrap();
assert_eq!(msg.topic, "room:lobby");
assert_eq!(msg.reference, None);
@@ -943,8 +932,8 @@ mod tests {
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
let expected_reply = Payload::<()>::Reply(Reply::Error {
reason: ErrorReply::UnmatchedTopic,
});
assert_eq!(actual_reply, expected_reply);
@@ -960,8 +949,8 @@ mod tests {
"payload": {}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Close(Empty {});
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
let expected_reply = Payload::<()>::Close(Empty {});
assert_eq!(actual_reply, expected_reply);
}
@@ -975,8 +964,8 @@ mod tests {
"payload": { "reason": "token_expired" }
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Disconnect {
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
let expected_reply = Payload::<()>::Disconnect {
reason: DisconnectReason::TokenExpired,
};
assert_eq!(actual_reply, expected_reply);
@@ -997,8 +986,8 @@ mod tests {
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
let expected_reply = Payload::<()>::Reply(Reply::Error {
reason: ErrorReply::Other,
});
assert_eq!(actual_reply, expected_reply);
@@ -1019,8 +1008,8 @@ mod tests {
}
}
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
let expected_reply = Payload::<()>::Reply(Reply::Error {
reason: ErrorReply::InvalidVersion,
});
assert_eq!(actual_reply, expected_reply);
@@ -1030,7 +1019,7 @@ mod tests {
fn disabled_err_reply() {
let json = r#"{"event":"phx_reply","ref":null,"topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#;
let actual = serde_json::from_str::<PhoenixMessage<(), ()>>(json).unwrap();
let actual = serde_json::from_str::<PhoenixMessage<()>>(json).unwrap();
let expected = PhoenixMessage::new_err_reply("client", ErrorReply::Disabled, None);
assert_eq!(actual, expected)

View File

@@ -70,7 +70,7 @@ async fn client_does_not_pipeline_messages() {
.unwrap(),
);
let mut channel = PhoenixChannel::<(), InboundMsg, (), _>::disconnected(
let mut channel = PhoenixChannel::<(), InboundMsg, _>::disconnected(
login_url,
"test/1.0.0".to_owned(),
"test",

View File

@@ -79,7 +79,7 @@ enum PortalCommand {
impl Eventloop {
pub(crate) fn new(
tunnel: GatewayTunnel,
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
tun_device_manager: TunDeviceManager,
) -> Result<Self> {
portal.connect(PublicKeyParam(tunnel.public_key().to_bytes()));
@@ -637,7 +637,7 @@ impl Eventloop {
}
async fn phoenix_channel_event_loop(
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
event_tx: mpsc::Sender<Result<IngressMessages, phoenix_channel::Error>>,
mut cmd_rx: mpsc::Receiver<PortalCommand>,
) {
@@ -661,7 +661,7 @@ async fn phoenix_channel_event_loop(
}
Either::Left((
Ok(
phoenix_channel::Event::SuccessResponse { res: (), .. }
phoenix_channel::Event::SuccessResponse { .. }
| phoenix_channel::Event::HeartbeatSent
| phoenix_channel::Event::JoinedRoom { .. },
),

View File

@@ -416,7 +416,7 @@ struct Eventloop<R> {
sockets: Sockets,
server: Server<R>,
channel: Option<PhoenixChannel<JoinMessage, IngressMessage, (), NoParams>>,
channel: Option<PhoenixChannel<JoinMessage, IngressMessage, NoParams>>,
sleep: Sleep,
ebpf: Option<ebpf::Program>,
@@ -438,7 +438,7 @@ where
fn new(
server: Server<R>,
ebpf: Option<ebpf::Program>,
channel: PhoenixChannel<JoinMessage, IngressMessage, (), NoParams>,
channel: PhoenixChannel<JoinMessage, IngressMessage, NoParams>,
public_address: IpStack,
last_heartbeat_sent: Arc<Mutex<Option<Instant>>>,
) -> Result<Self> {
@@ -719,9 +719,9 @@ where
Ok(())
}
fn handle_portal_event(&mut self, event: phoenix_channel::Event<IngressMessage, ()>) {
fn handle_portal_event(&mut self, event: phoenix_channel::Event<IngressMessage>) {
match event {
Event::SuccessResponse { res: (), .. } => {}
Event::SuccessResponse { .. } => {}
Event::JoinedRoom { topic } => {
tracing::info!(target: "relay", "Successfully joined room '{topic}'");
}