mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): remove concept of "ReplyMessages" (#10361)
In earlier versions of Firezone, the WebSocket protocol with the portal was using the request-response semantics built into Phoenix. This however is quite cumbersome to work with to due to the polymorphic nature of the protocol design. We ended up moving away from it and instead only use one-way messages where each event directly corresponds to a message type. However, we have never removed the capability reply messages from the `phoenix-channel` module, instead all usages just set it to `()`. We can simplify the code here by always setting this to `()`. Resolves: #7091
This commit is contained in:
@@ -96,7 +96,7 @@ impl Eventloop {
|
||||
pub(crate) fn new(
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
|
||||
cmd_rx: mpsc::UnboundedReceiver<Command>,
|
||||
resource_list_sender: watch::Sender<Vec<ResourceView>>,
|
||||
tun_config_sender: watch::Sender<Option<TunConfig>>,
|
||||
@@ -432,7 +432,7 @@ impl Eventloop {
|
||||
}
|
||||
|
||||
async fn phoenix_channel_event_loop(
|
||||
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
|
||||
event_tx: mpsc::Sender<Result<IngressMessages, phoenix_channel::Error>>,
|
||||
mut cmd_rx: mpsc::Receiver<PortalCommand>,
|
||||
) {
|
||||
@@ -449,7 +449,7 @@ async fn phoenix_channel_event_loop(
|
||||
break;
|
||||
}
|
||||
}
|
||||
Either::Left((Ok(phoenix_channel::Event::SuccessResponse { res: (), .. }), _)) => {}
|
||||
Either::Left((Ok(phoenix_channel::Event::SuccessResponse { .. }), _)) => {}
|
||||
Either::Left((Ok(phoenix_channel::Event::ErrorResponse { res, req_id, topic }), _)) => {
|
||||
match res {
|
||||
ErrorReply::Disabled => {
|
||||
|
||||
@@ -57,7 +57,7 @@ impl Session {
|
||||
pub fn connect(
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
|
||||
handle: tokio::runtime::Handle,
|
||||
) -> (Self, EventStream) {
|
||||
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
||||
|
||||
@@ -36,7 +36,7 @@ pub use tokio_tungstenite::tungstenite::http::StatusCode;
|
||||
|
||||
const MAX_BUFFERED_MESSAGES: usize = 32; // Chosen pretty arbitrarily. If we are connected, these should never build up.
|
||||
|
||||
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish> {
|
||||
pub struct PhoenixChannel<TInitReq, TInboundMsg, TFinish> {
|
||||
state: State,
|
||||
waker: Option<Waker>,
|
||||
pending_joins: VecDeque<String>,
|
||||
@@ -46,7 +46,7 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish> {
|
||||
|
||||
heartbeat: tokio::time::Interval,
|
||||
|
||||
_phantom: PhantomData<(TInboundMsg, TOutboundRes)>,
|
||||
_phantom: PhantomData<TInboundMsg>,
|
||||
|
||||
pending_join_requests: HashMap<OutboundRequestId, Instant>,
|
||||
|
||||
@@ -246,12 +246,10 @@ impl fmt::Display for OutboundRequestId {
|
||||
#[error("Cannot close websocket while we are connecting")]
|
||||
pub struct Connecting;
|
||||
|
||||
impl<TInitReq, TInboundMsg, TOutboundRes, TFinish>
|
||||
PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish>
|
||||
impl<TInitReq, TInboundMsg, TFinish> PhoenixChannel<TInitReq, TInboundMsg, TFinish>
|
||||
where
|
||||
TInitReq: Serialize + Clone,
|
||||
TInboundMsg: DeserializeOwned,
|
||||
TOutboundRes: DeserializeOwned,
|
||||
TFinish: IntoIterator<Item = (&'static str, String)>,
|
||||
{
|
||||
/// Creates a new [PhoenixChannel] to the given endpoint in the `disconnected` state.
|
||||
@@ -374,10 +372,7 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn poll(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
) -> Poll<Result<Event<TInboundMsg, TOutboundRes>, Error>> {
|
||||
pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<Event<TInboundMsg>, Error>> {
|
||||
loop {
|
||||
// First, check if we are connected.
|
||||
let stream = match &mut self.state {
|
||||
@@ -547,20 +542,21 @@ where
|
||||
|
||||
tracing::trace!(target: "wire::api::recv", %message);
|
||||
|
||||
let message = match serde_json::from_str::<
|
||||
PhoenixMessage<TInboundMsg, TOutboundRes>,
|
||||
>(&message)
|
||||
{
|
||||
Ok(m) => m,
|
||||
Err(e) if e.is_io() || e.is_eof() => {
|
||||
self.reconnect_on_transient_error(InternalError::Serde(e));
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to deserialize message: {}", err_with_src(&e));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let message =
|
||||
match serde_json::from_str::<PhoenixMessage<TInboundMsg>>(&message) {
|
||||
Ok(m) => m,
|
||||
Err(e) if e.is_io() || e.is_eof() => {
|
||||
self.reconnect_on_transient_error(InternalError::Serde(e));
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to deserialize message: {}",
|
||||
err_with_src(&e)
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match (message.payload, message.reference) {
|
||||
(Payload::Message(msg), _) => {
|
||||
@@ -586,11 +582,10 @@ where
|
||||
res: reason,
|
||||
}));
|
||||
}
|
||||
(Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => {
|
||||
(Payload::Reply(Reply::Ok(OkReply::Message(()))), Some(req_id)) => {
|
||||
return Poll::Ready(Ok(Event::SuccessResponse {
|
||||
topic: message.topic,
|
||||
req_id,
|
||||
res: reply,
|
||||
}));
|
||||
}
|
||||
(Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => {
|
||||
@@ -710,12 +705,10 @@ where
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Event<TInboundMsg, TOutboundRes> {
|
||||
pub enum Event<TInboundMsg> {
|
||||
SuccessResponse {
|
||||
topic: String,
|
||||
req_id: OutboundRequestId,
|
||||
/// The response received for an outbound request.
|
||||
res: TOutboundRes,
|
||||
},
|
||||
ErrorResponse {
|
||||
topic: String,
|
||||
@@ -741,20 +734,20 @@ pub enum Event<TInboundMsg, TOutboundRes> {
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct PhoenixMessage<T, R> {
|
||||
pub struct PhoenixMessage<T> {
|
||||
// TODO: we should use a newtype pattern for topics
|
||||
topic: String,
|
||||
#[serde(flatten)]
|
||||
payload: Payload<T, R>,
|
||||
payload: Payload<T>,
|
||||
#[serde(rename = "ref")]
|
||||
reference: Option<OutboundRequestId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
||||
#[serde(tag = "event", content = "payload")]
|
||||
enum Payload<T, R> {
|
||||
enum Payload<T> {
|
||||
#[serde(rename = "phx_reply")]
|
||||
Reply(Reply<R>),
|
||||
Reply(Reply),
|
||||
#[serde(rename = "phx_error")]
|
||||
Error(Empty),
|
||||
#[serde(rename = "phx_close")]
|
||||
@@ -772,8 +765,8 @@ struct Empty {}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
||||
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
|
||||
enum Reply<T> {
|
||||
Ok(OkReply<T>),
|
||||
enum Reply {
|
||||
Ok(OkReply<()>), // We never expect responses for our requests.
|
||||
Error { reason: ErrorReply },
|
||||
}
|
||||
|
||||
@@ -814,7 +807,7 @@ pub enum DisconnectReason {
|
||||
TokenExpired,
|
||||
}
|
||||
|
||||
impl<T, R> PhoenixMessage<T, R> {
|
||||
impl<T> PhoenixMessage<T> {
|
||||
pub fn new_message(
|
||||
topic: impl Into<String>,
|
||||
payload: T,
|
||||
@@ -827,14 +820,10 @@ impl<T, R> PhoenixMessage<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_ok_reply(
|
||||
topic: impl Into<String>,
|
||||
payload: R,
|
||||
reference: Option<OutboundRequestId>,
|
||||
) -> Self {
|
||||
pub fn new_ok_reply(topic: impl Into<String>, reference: Option<OutboundRequestId>) -> Self {
|
||||
Self {
|
||||
topic: topic.into(),
|
||||
payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))),
|
||||
payload: Payload::Reply(Reply::Ok(OkReply::Message(()))),
|
||||
reference,
|
||||
}
|
||||
}
|
||||
@@ -886,7 +875,7 @@ fn serialize_msg(
|
||||
payload: impl Serialize,
|
||||
request_id: OutboundRequestId,
|
||||
) -> String {
|
||||
serde_json::to_string(&PhoenixMessage::<_, ()>::new_message(
|
||||
serde_json::to_string(&PhoenixMessage::new_message(
|
||||
topic,
|
||||
payload,
|
||||
Some(request_id),
|
||||
@@ -916,7 +905,7 @@ mod tests {
|
||||
"event": "shout"
|
||||
}"#;
|
||||
|
||||
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
|
||||
let msg = serde_json::from_str::<PhoenixMessage<Msg>>(msg).unwrap();
|
||||
|
||||
assert_eq!(msg.topic, "room:lobby");
|
||||
assert_eq!(msg.reference, None);
|
||||
@@ -943,8 +932,8 @@ mod tests {
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
|
||||
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<()>::Reply(Reply::Error {
|
||||
reason: ErrorReply::UnmatchedTopic,
|
||||
});
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
@@ -960,8 +949,8 @@ mod tests {
|
||||
"payload": {}
|
||||
}
|
||||
"#;
|
||||
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<(), ()>::Close(Empty {});
|
||||
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<()>::Close(Empty {});
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
}
|
||||
|
||||
@@ -975,8 +964,8 @@ mod tests {
|
||||
"payload": { "reason": "token_expired" }
|
||||
}
|
||||
"#;
|
||||
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<(), ()>::Disconnect {
|
||||
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<()>::Disconnect {
|
||||
reason: DisconnectReason::TokenExpired,
|
||||
};
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
@@ -997,8 +986,8 @@ mod tests {
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
|
||||
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<()>::Reply(Reply::Error {
|
||||
reason: ErrorReply::Other,
|
||||
});
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
@@ -1019,8 +1008,8 @@ mod tests {
|
||||
}
|
||||
}
|
||||
"#;
|
||||
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<(), ()>::Reply(Reply::Error {
|
||||
let actual_reply = serde_json::from_str::<Payload<()>>(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<()>::Reply(Reply::Error {
|
||||
reason: ErrorReply::InvalidVersion,
|
||||
});
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
@@ -1030,7 +1019,7 @@ mod tests {
|
||||
fn disabled_err_reply() {
|
||||
let json = r#"{"event":"phx_reply","ref":null,"topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#;
|
||||
|
||||
let actual = serde_json::from_str::<PhoenixMessage<(), ()>>(json).unwrap();
|
||||
let actual = serde_json::from_str::<PhoenixMessage<()>>(json).unwrap();
|
||||
let expected = PhoenixMessage::new_err_reply("client", ErrorReply::Disabled, None);
|
||||
|
||||
assert_eq!(actual, expected)
|
||||
|
||||
@@ -70,7 +70,7 @@ async fn client_does_not_pipeline_messages() {
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let mut channel = PhoenixChannel::<(), InboundMsg, (), _>::disconnected(
|
||||
let mut channel = PhoenixChannel::<(), InboundMsg, _>::disconnected(
|
||||
login_url,
|
||||
"test/1.0.0".to_owned(),
|
||||
"test",
|
||||
|
||||
@@ -79,7 +79,7 @@ enum PortalCommand {
|
||||
impl Eventloop {
|
||||
pub(crate) fn new(
|
||||
tunnel: GatewayTunnel,
|
||||
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
|
||||
tun_device_manager: TunDeviceManager,
|
||||
) -> Result<Self> {
|
||||
portal.connect(PublicKeyParam(tunnel.public_key().to_bytes()));
|
||||
@@ -637,7 +637,7 @@ impl Eventloop {
|
||||
}
|
||||
|
||||
async fn phoenix_channel_event_loop(
|
||||
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
|
||||
mut portal: PhoenixChannel<(), IngressMessages, PublicKeyParam>,
|
||||
event_tx: mpsc::Sender<Result<IngressMessages, phoenix_channel::Error>>,
|
||||
mut cmd_rx: mpsc::Receiver<PortalCommand>,
|
||||
) {
|
||||
@@ -661,7 +661,7 @@ async fn phoenix_channel_event_loop(
|
||||
}
|
||||
Either::Left((
|
||||
Ok(
|
||||
phoenix_channel::Event::SuccessResponse { res: (), .. }
|
||||
phoenix_channel::Event::SuccessResponse { .. }
|
||||
| phoenix_channel::Event::HeartbeatSent
|
||||
| phoenix_channel::Event::JoinedRoom { .. },
|
||||
),
|
||||
|
||||
@@ -416,7 +416,7 @@ struct Eventloop<R> {
|
||||
sockets: Sockets,
|
||||
|
||||
server: Server<R>,
|
||||
channel: Option<PhoenixChannel<JoinMessage, IngressMessage, (), NoParams>>,
|
||||
channel: Option<PhoenixChannel<JoinMessage, IngressMessage, NoParams>>,
|
||||
sleep: Sleep,
|
||||
|
||||
ebpf: Option<ebpf::Program>,
|
||||
@@ -438,7 +438,7 @@ where
|
||||
fn new(
|
||||
server: Server<R>,
|
||||
ebpf: Option<ebpf::Program>,
|
||||
channel: PhoenixChannel<JoinMessage, IngressMessage, (), NoParams>,
|
||||
channel: PhoenixChannel<JoinMessage, IngressMessage, NoParams>,
|
||||
public_address: IpStack,
|
||||
last_heartbeat_sent: Arc<Mutex<Option<Instant>>>,
|
||||
) -> Result<Self> {
|
||||
@@ -719,9 +719,9 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_portal_event(&mut self, event: phoenix_channel::Event<IngressMessage, ()>) {
|
||||
fn handle_portal_event(&mut self, event: phoenix_channel::Event<IngressMessage>) {
|
||||
match event {
|
||||
Event::SuccessResponse { res: (), .. } => {}
|
||||
Event::SuccessResponse { .. } => {}
|
||||
Event::JoinedRoom { topic } => {
|
||||
tracing::info!(target: "relay", "Successfully joined room '{topic}'");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user