mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
The `connlib-shared` crate has become a bit of a dependency magnet without a clear purpose. It hosts utilities like `get_user_agent`, messages for the client and gateway to communicate with the portal and domain types like `ResourceId`. To create a better dependency structure in our workspace, we repurpose `connlib-shared` as a `connlib-model` crate. Its purpose is to host domain-specific model types that multiple crates may want to use. For that purpose, we rename the `callbacks::ResourceDescription` type to `ResourceView`, designating that this is a _view_ onto a resource as seen by `connlib`. The message types which currently double up as connlib-internal model thus become an implementation detail of `firezone-tunnel` and shouldn't be used for anything else. --------- Signed-off-by: Reactor Scram <ReactorScram@users.noreply.github.com> Co-authored-by: Reactor Scram <ReactorScram@users.noreply.github.com>
943 lines
31 KiB
Rust
943 lines
31 KiB
Rust
mod get_user_agent;
|
|
mod heartbeat;
|
|
mod login_url;
|
|
|
|
use std::collections::{HashSet, VecDeque};
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::sync::atomic::AtomicU64;
|
|
use std::sync::Arc;
|
|
use std::{fmt, future, marker::PhantomData};
|
|
use std::{io, mem};
|
|
|
|
use backoff::backoff::Backoff;
|
|
use backoff::ExponentialBackoff;
|
|
use base64::Engine;
|
|
use futures::future::BoxFuture;
|
|
use futures::{FutureExt, SinkExt, StreamExt};
|
|
use heartbeat::{Heartbeat, MissedLastHeartbeat};
|
|
use rand_core::{OsRng, RngCore};
|
|
use secrecy::{ExposeSecret, Secret};
|
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
|
use socket_factory::{SocketFactory, TcpSocket, TcpStream};
|
|
use std::task::{Context, Poll, Waker};
|
|
use tokio_tungstenite::client_async_tls;
|
|
use tokio_tungstenite::tungstenite::http::StatusCode;
|
|
use tokio_tungstenite::{
|
|
tungstenite::{handshake::client::Request, Message},
|
|
MaybeTlsStream, WebSocketStream,
|
|
};
|
|
use url::{Host, Url};
|
|
|
|
pub use get_user_agent::get_user_agent;
|
|
pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError};
|
|
|
|
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
|
|
state: State,
|
|
waker: Option<Waker>,
|
|
pending_messages: VecDeque<String>,
|
|
next_request_id: Arc<AtomicU64>,
|
|
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
|
|
|
heartbeat: Heartbeat,
|
|
|
|
_phantom: PhantomData<(TInboundMsg, TOutboundRes)>,
|
|
|
|
pending_join_requests: HashSet<OutboundRequestId>,
|
|
|
|
// Stored here to allow re-connecting.
|
|
url: Secret<LoginUrl>,
|
|
user_agent: String,
|
|
reconnect_backoff: ExponentialBackoff,
|
|
|
|
resolved_addresses: Vec<IpAddr>,
|
|
|
|
login: &'static str,
|
|
init_req: TInitReq,
|
|
}
|
|
|
|
enum State {
|
|
Connected(WebSocketStream<MaybeTlsStream<TcpStream>>),
|
|
Connecting(
|
|
BoxFuture<'static, Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError>>,
|
|
),
|
|
Closing(WebSocketStream<MaybeTlsStream<TcpStream>>),
|
|
Closed,
|
|
}
|
|
|
|
impl State {
|
|
fn connect(
|
|
url: Secret<LoginUrl>,
|
|
user_agent: String,
|
|
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
|
) -> Self {
|
|
Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed())
|
|
}
|
|
}
|
|
|
|
async fn create_and_connect_websocket(
|
|
url: Secret<LoginUrl>,
|
|
user_agent: String,
|
|
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
|
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError> {
|
|
let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?;
|
|
|
|
let (stream, _) = client_async_tls(make_request(url, user_agent), socket)
|
|
.await
|
|
.map_err(InternalError::WebSocket)?;
|
|
|
|
Ok(stream)
|
|
}
|
|
|
|
async fn make_socket(
|
|
url: &Url,
|
|
socket_factory: &dyn SocketFactory<TcpSocket>,
|
|
) -> Result<TcpStream, InternalError> {
|
|
let port = url
|
|
.port_or_known_default()
|
|
.expect("scheme to be http, https, ws or wss");
|
|
let addrs: Vec<SocketAddr> = match url.host().ok_or(InternalError::InvalidUrl)? {
|
|
Host::Domain(n) => tokio::net::lookup_host((n, port))
|
|
.await
|
|
.map_err(|_| InternalError::InvalidUrl)?
|
|
.collect(),
|
|
Host::Ipv6(ip) => {
|
|
vec![(ip, port).into()]
|
|
}
|
|
Host::Ipv4(ip) => {
|
|
vec![(ip, port).into()]
|
|
}
|
|
};
|
|
|
|
let mut last_error = None;
|
|
for addr in addrs {
|
|
let Ok(socket) = socket_factory(&addr) else {
|
|
continue;
|
|
};
|
|
|
|
match socket.connect(addr).await {
|
|
Ok(socket) => return Ok(socket),
|
|
Err(e) => {
|
|
last_error = Some(e);
|
|
}
|
|
}
|
|
}
|
|
|
|
let Some(err) = last_error else {
|
|
return Err(InternalError::InvalidUrl);
|
|
};
|
|
|
|
Err(InternalError::SocketConnection(err))
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum Error {
|
|
#[error("client error: {0}")]
|
|
Client(StatusCode),
|
|
#[error("token expired")]
|
|
TokenExpired,
|
|
#[error("max retries reached")]
|
|
MaxRetriesReached,
|
|
#[error("login failed: {0}")]
|
|
LoginFailed(ErrorReply),
|
|
}
|
|
|
|
impl Error {
|
|
pub fn is_authentication_error(&self) -> bool {
|
|
match self {
|
|
Error::Client(s) => s == &StatusCode::UNAUTHORIZED || s == &StatusCode::FORBIDDEN,
|
|
Error::TokenExpired => true,
|
|
Error::MaxRetriesReached => false,
|
|
Error::LoginFailed(_) => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
enum InternalError {
|
|
WebSocket(tokio_tungstenite::tungstenite::Error),
|
|
Serde(serde_json::Error),
|
|
MissedHeartbeat,
|
|
CloseMessage,
|
|
StreamClosed,
|
|
InvalidUrl,
|
|
SocketConnection(std::io::Error),
|
|
}
|
|
|
|
impl fmt::Display for InternalError {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(http)) => {
|
|
let status = http.status();
|
|
let body = http
|
|
.body()
|
|
.as_deref()
|
|
.map(String::from_utf8_lossy)
|
|
.unwrap_or_default();
|
|
|
|
write!(f, "http error: {status} - {body}")
|
|
}
|
|
InternalError::WebSocket(e) => write!(f, "websocket connection failed: {e}"),
|
|
InternalError::Serde(e) => write!(f, "failed to deserialize message: {e}"),
|
|
InternalError::MissedHeartbeat => write!(f, "portal did not respond to our heartbeat"),
|
|
InternalError::CloseMessage => write!(f, "portal closed the websocket connection"),
|
|
InternalError::StreamClosed => write!(f, "websocket stream was closed"),
|
|
InternalError::InvalidUrl => write!(f, "failed to resolve url"),
|
|
InternalError::SocketConnection(e) => write!(f, "failed to connect socket: {e}"),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A strict-monotonically increasing ID for outbound requests.
|
|
#[derive(Debug, PartialEq, Eq, Hash, Deserialize, Serialize, PartialOrd, Ord)]
|
|
pub struct OutboundRequestId(u64);
|
|
|
|
impl OutboundRequestId {
|
|
// Should only be used for unit-testing.
|
|
pub fn for_test(id: u64) -> Self {
|
|
Self(id)
|
|
}
|
|
|
|
/// Internal function to make a copy.
|
|
///
|
|
/// Not exposed publicly because these IDs are meant to be unique.
|
|
pub(crate) fn copy(&self) -> Self {
|
|
Self(self.0)
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for OutboundRequestId {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
write!(f, "OutReq-{}", self.0)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
#[error("Cannot close websocket while we are connecting")]
|
|
pub struct Connecting;
|
|
|
|
impl<TInitReq, TInboundMsg, TOutboundRes> PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes>
|
|
where
|
|
TInitReq: Serialize + Clone,
|
|
TInboundMsg: DeserializeOwned,
|
|
TOutboundRes: DeserializeOwned,
|
|
{
|
|
/// Creates a new [PhoenixChannel] to the given endpoint.
|
|
///
|
|
/// The provided URL must contain a host.
|
|
/// Additionally, you must already provide any query parameters required for authentication.
|
|
pub fn connect(
|
|
url: Secret<LoginUrl>,
|
|
user_agent: String,
|
|
login: &'static str,
|
|
init_req: TInitReq,
|
|
reconnect_backoff: ExponentialBackoff,
|
|
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
|
) -> io::Result<Self> {
|
|
let next_request_id = Arc::new(AtomicU64::new(0));
|
|
|
|
// Statically resolve the host in the URL to a set of addresses.
|
|
// We don't use these directly because we need to connect to the domain via TLS which requires a hostname.
|
|
// We expose them to other components that deal with DNS stuff to ensure our domain always resolves to these IPs.
|
|
let resolved_addresses = url
|
|
.expose_secret()
|
|
.inner()
|
|
.socket_addrs(|| None)?
|
|
.iter()
|
|
.map(|addr| addr.ip())
|
|
.collect();
|
|
|
|
tracing::debug!(host = %url.expose_secret().host(), %user_agent, "Connecting to portal");
|
|
|
|
Ok(Self {
|
|
reconnect_backoff,
|
|
url: url.clone(),
|
|
user_agent: user_agent.clone(),
|
|
state: State::connect(url, user_agent, socket_factory.clone()),
|
|
socket_factory,
|
|
waker: None,
|
|
pending_messages: Default::default(),
|
|
_phantom: PhantomData,
|
|
heartbeat: Heartbeat::new(
|
|
heartbeat::INTERVAL,
|
|
heartbeat::TIMEOUT,
|
|
next_request_id.clone(),
|
|
),
|
|
next_request_id,
|
|
pending_join_requests: Default::default(),
|
|
login,
|
|
init_req,
|
|
resolved_addresses,
|
|
})
|
|
}
|
|
|
|
/// Returns the addresses that have been resolved for our server host.
|
|
pub fn resolved_addresses(&self) -> Vec<IpAddr> {
|
|
self.resolved_addresses.clone()
|
|
}
|
|
|
|
/// The host we are connecting / connected to.
|
|
pub fn server_host(&self) -> &str {
|
|
self.url.expose_secret().host()
|
|
}
|
|
|
|
/// Join the provided room.
|
|
///
|
|
/// If successful, a [`Event::JoinedRoom`] event will be emitted.
|
|
pub fn join(&mut self, topic: impl Into<String>, payload: impl Serialize) {
|
|
let (request_id, msg) = self.make_message(topic, EgressControlMessage::PhxJoin(payload));
|
|
self.pending_messages.push_front(msg); // Must send the join message before all others.
|
|
|
|
self.pending_join_requests.insert(request_id);
|
|
}
|
|
|
|
/// Send a message to a topic.
|
|
pub fn send(&mut self, topic: impl Into<String>, message: impl Serialize) -> OutboundRequestId {
|
|
let (id, msg) = self.make_message(topic, message);
|
|
self.pending_messages.push_back(msg);
|
|
|
|
id
|
|
}
|
|
|
|
/// Reconnects to the portal.
|
|
pub fn reconnect(&mut self) {
|
|
// 1. Reset the backoff.
|
|
self.reconnect_backoff.reset();
|
|
|
|
// 2. Set state to `Connecting` without a timer.
|
|
let url = self.url.clone();
|
|
let user_agent = self.user_agent.clone();
|
|
self.state = State::connect(url, user_agent, self.socket_factory.clone());
|
|
|
|
// 3. In case we were already re-connecting, we need to wake the suspended task.
|
|
if let Some(waker) = self.waker.take() {
|
|
waker.wake();
|
|
}
|
|
}
|
|
|
|
/// Initiate a graceful close of the connection.
|
|
pub fn close(&mut self) -> Result<(), Connecting> {
|
|
tracing::info!("Closing connection to portal");
|
|
|
|
match mem::replace(&mut self.state, State::Closed) {
|
|
State::Connecting(_) => return Err(Connecting),
|
|
State::Closing(stream) | State::Connected(stream) => {
|
|
self.state = State::Closing(stream);
|
|
}
|
|
State::Closed => {}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn poll(
|
|
&mut self,
|
|
cx: &mut Context,
|
|
) -> Poll<Result<Event<TInboundMsg, TOutboundRes>, Error>> {
|
|
loop {
|
|
// First, check if we are connected.
|
|
let stream = match &mut self.state {
|
|
State::Closed => return Poll::Ready(Ok(Event::Closed)),
|
|
State::Closing(stream) => match stream.poll_close_unpin(cx) {
|
|
Poll::Ready(Ok(())) => {
|
|
tracing::info!("Closed websocket connection to portal");
|
|
|
|
self.state = State::Closed;
|
|
|
|
return Poll::Ready(Ok(Event::Closed));
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
tracing::warn!("Error while closing websocket: {e}");
|
|
|
|
return Poll::Ready(Ok(Event::Closed));
|
|
}
|
|
Poll::Pending => return Poll::Pending,
|
|
},
|
|
State::Connected(stream) => stream,
|
|
State::Connecting(future) => match future.poll_unpin(cx) {
|
|
Poll::Ready(Ok(stream)) => {
|
|
self.reconnect_backoff.reset();
|
|
self.heartbeat.reset();
|
|
self.state = State::Connected(stream);
|
|
|
|
let host = self.url.expose_secret().host();
|
|
|
|
tracing::info!(%host, "Connected to portal");
|
|
self.join(self.login, self.init_req.clone());
|
|
|
|
continue;
|
|
}
|
|
Poll::Ready(Err(InternalError::WebSocket(
|
|
tokio_tungstenite::tungstenite::Error::Http(r),
|
|
))) if r.status().is_client_error() => {
|
|
return Poll::Ready(Err(Error::Client(r.status())));
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
let Some(backoff) = self.reconnect_backoff.next_backoff() else {
|
|
tracing::warn!("Reconnect backoff expired");
|
|
return Poll::Ready(Err(Error::MaxRetriesReached));
|
|
};
|
|
|
|
let secret_url = self.url.clone();
|
|
let user_agent = self.user_agent.clone();
|
|
let socket_factory = self.socket_factory.clone();
|
|
|
|
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}");
|
|
|
|
self.state = State::Connecting(Box::pin(async move {
|
|
tokio::time::sleep(backoff).await;
|
|
create_and_connect_websocket(secret_url, user_agent, socket_factory)
|
|
.await
|
|
}));
|
|
|
|
continue;
|
|
}
|
|
Poll::Pending => {
|
|
// Save a waker in case we want to reset the `Connecting` state while we are waiting.
|
|
self.waker = Some(cx.waker().clone());
|
|
return Poll::Pending;
|
|
}
|
|
},
|
|
};
|
|
|
|
// Priority 1: Keep local buffers small and send pending messages.
|
|
match stream.poll_ready_unpin(cx) {
|
|
Poll::Ready(Ok(())) => {
|
|
if let Some(message) = self.pending_messages.pop_front() {
|
|
match stream.start_send_unpin(Message::Text(message.clone())) {
|
|
Ok(()) => {
|
|
tracing::trace!(target: "wire::api::send", %message);
|
|
|
|
match stream.poll_flush_unpin(cx) {
|
|
Poll::Ready(Ok(())) => {
|
|
tracing::trace!("Flushed websocket");
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
self.reconnect_on_transient_error(
|
|
InternalError::WebSocket(e),
|
|
);
|
|
continue;
|
|
}
|
|
Poll::Pending => {}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
self.pending_messages.push_front(message);
|
|
self.reconnect_on_transient_error(InternalError::WebSocket(e));
|
|
}
|
|
}
|
|
continue;
|
|
}
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
self.reconnect_on_transient_error(InternalError::WebSocket(e));
|
|
continue;
|
|
}
|
|
Poll::Pending => {}
|
|
}
|
|
|
|
// Priority 2: Handle incoming messages.
|
|
match stream.poll_next_unpin(cx) {
|
|
Poll::Ready(Some(Ok(message))) => {
|
|
let Ok(message) = message.into_text() else {
|
|
tracing::warn!("Received non-text message from portal");
|
|
continue;
|
|
};
|
|
|
|
tracing::trace!(target: "wire::api::recv", %message);
|
|
|
|
let message = match serde_json::from_str::<
|
|
PhoenixMessage<TInboundMsg, TOutboundRes>,
|
|
>(&message)
|
|
{
|
|
Ok(m) => m,
|
|
Err(e) if e.is_io() || e.is_eof() => {
|
|
self.reconnect_on_transient_error(InternalError::Serde(e));
|
|
continue;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("Failed to deserialize message: {e}");
|
|
continue;
|
|
}
|
|
};
|
|
|
|
match (message.payload, message.reference) {
|
|
(Payload::Message(msg), _) => {
|
|
return Poll::Ready(Ok(Event::InboundMessage {
|
|
topic: message.topic,
|
|
msg,
|
|
}))
|
|
}
|
|
(Payload::Reply(_), None) => {
|
|
tracing::warn!("Discarding reply because server omitted reference");
|
|
continue;
|
|
}
|
|
(Payload::Reply(Reply::Error { reason }), Some(req_id)) => {
|
|
if message.topic == self.login
|
|
&& self.pending_join_requests.contains(&req_id)
|
|
{
|
|
return Poll::Ready(Err(Error::LoginFailed(reason)));
|
|
}
|
|
|
|
return Poll::Ready(Ok(Event::ErrorResponse {
|
|
topic: message.topic,
|
|
req_id,
|
|
res: reason,
|
|
}));
|
|
}
|
|
(Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => {
|
|
if self.pending_join_requests.remove(&req_id) {
|
|
tracing::info!("Joined {} room on portal", message.topic);
|
|
|
|
// For `phx_join` requests, `reply` is empty so we can safely ignore it.
|
|
return Poll::Ready(Ok(Event::JoinedRoom {
|
|
topic: message.topic,
|
|
}));
|
|
}
|
|
|
|
return Poll::Ready(Ok(Event::SuccessResponse {
|
|
topic: message.topic,
|
|
req_id,
|
|
res: reply,
|
|
}));
|
|
}
|
|
(Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => {
|
|
if self.heartbeat.maybe_handle_reply(req_id.copy()) {
|
|
continue;
|
|
}
|
|
|
|
tracing::trace!("Received empty reply for request {req_id:?}");
|
|
|
|
continue;
|
|
}
|
|
(Payload::Error(Empty {}), reference) => {
|
|
tracing::debug!(
|
|
?reference,
|
|
topic = &message.topic,
|
|
"Received empty error response"
|
|
);
|
|
continue;
|
|
}
|
|
(Payload::Close(Empty {}), _) => {
|
|
self.reconnect_on_transient_error(InternalError::CloseMessage);
|
|
continue;
|
|
}
|
|
(
|
|
Payload::Disconnect {
|
|
reason: DisconnectReason::TokenExpired,
|
|
},
|
|
_,
|
|
) => {
|
|
return Poll::Ready(Err(Error::TokenExpired));
|
|
}
|
|
}
|
|
}
|
|
Poll::Ready(Some(Err(e))) => {
|
|
self.reconnect_on_transient_error(InternalError::WebSocket(e));
|
|
continue;
|
|
}
|
|
Poll::Ready(None) => {
|
|
self.reconnect_on_transient_error(InternalError::StreamClosed);
|
|
continue;
|
|
}
|
|
Poll::Pending => {}
|
|
}
|
|
|
|
// Priority 3: Handle heartbeats.
|
|
match self.heartbeat.poll(cx) {
|
|
Poll::Ready(Ok(id)) => {
|
|
self.pending_messages.push_back(serialize_msg(
|
|
"phoenix",
|
|
EgressControlMessage::<()>::Heartbeat(Empty {}),
|
|
id.copy(),
|
|
));
|
|
|
|
return Poll::Ready(Ok(Event::HeartbeatSent));
|
|
}
|
|
Poll::Ready(Err(MissedLastHeartbeat {})) => {
|
|
self.reconnect_on_transient_error(InternalError::MissedHeartbeat);
|
|
continue;
|
|
}
|
|
Poll::Pending => {}
|
|
}
|
|
|
|
return Poll::Pending;
|
|
}
|
|
}
|
|
|
|
/// Sets the channels state to [`State::Connecting`] with the given error.
|
|
///
|
|
/// The [`PhoenixChannel::poll`] function will handle the reconnect if appropriate for the given error.
|
|
fn reconnect_on_transient_error(&mut self, e: InternalError) {
|
|
self.state = State::Connecting(future::ready(Err(e)).boxed())
|
|
}
|
|
|
|
fn make_message(
|
|
&mut self,
|
|
topic: impl Into<String>,
|
|
payload: impl Serialize,
|
|
) -> (OutboundRequestId, String) {
|
|
let request_id = self.fetch_add_request_id();
|
|
|
|
// We don't care about the reply type when serializing
|
|
let msg = serialize_msg(topic, payload, request_id.copy());
|
|
|
|
(request_id, msg)
|
|
}
|
|
|
|
fn fetch_add_request_id(&mut self) -> OutboundRequestId {
|
|
let next_id = self
|
|
.next_request_id
|
|
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
|
|
|
OutboundRequestId(next_id)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum Event<TInboundMsg, TOutboundRes> {
|
|
SuccessResponse {
|
|
topic: String,
|
|
req_id: OutboundRequestId,
|
|
/// The response received for an outbound request.
|
|
res: TOutboundRes,
|
|
},
|
|
ErrorResponse {
|
|
topic: String,
|
|
req_id: OutboundRequestId,
|
|
res: ErrorReply,
|
|
},
|
|
JoinedRoom {
|
|
topic: String,
|
|
},
|
|
HeartbeatSent,
|
|
/// The server sent us a message, most likely this is a broadcast to all connected clients.
|
|
InboundMessage {
|
|
topic: String,
|
|
msg: TInboundMsg,
|
|
},
|
|
/// The connection was closed successfully.
|
|
Closed,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
|
|
pub struct PhoenixMessage<T, R> {
|
|
// TODO: we should use a newtype pattern for topics
|
|
topic: String,
|
|
#[serde(flatten)]
|
|
payload: Payload<T, R>,
|
|
#[serde(rename = "ref")]
|
|
reference: Option<OutboundRequestId>,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
|
#[serde(tag = "event", content = "payload")]
|
|
enum Payload<T, R> {
|
|
#[serde(rename = "phx_reply")]
|
|
Reply(Reply<R>),
|
|
#[serde(rename = "phx_error")]
|
|
Error(Empty),
|
|
#[serde(rename = "phx_close")]
|
|
Close(Empty),
|
|
#[serde(rename = "disconnect")]
|
|
Disconnect { reason: DisconnectReason },
|
|
#[serde(untagged)]
|
|
Message(T),
|
|
}
|
|
|
|
// Awful hack to get serde_json to generate an empty "{}" instead of using "null"
|
|
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
|
|
#[serde(deny_unknown_fields)]
|
|
struct Empty {}
|
|
|
|
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
|
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
|
|
enum Reply<T> {
|
|
Ok(OkReply<T>),
|
|
Error { reason: ErrorReply },
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
#[serde(untagged)]
|
|
enum OkReply<T> {
|
|
Message(T),
|
|
NoMessage(Empty),
|
|
}
|
|
|
|
// TODO: I think this should also be a type-parameter.
|
|
/// This represents the info we have about the error
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum ErrorReply {
|
|
#[serde(rename = "unmatched topic")]
|
|
UnmatchedTopic,
|
|
NotFound,
|
|
InvalidVersion,
|
|
Offline,
|
|
Disabled,
|
|
#[serde(other)]
|
|
Other,
|
|
}
|
|
|
|
impl fmt::Display for ErrorReply {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
ErrorReply::UnmatchedTopic => write!(f, "unmatched topic"),
|
|
ErrorReply::NotFound => write!(f, "not found"),
|
|
ErrorReply::InvalidVersion => write!(f, "invalid version"),
|
|
ErrorReply::Offline => write!(f, "offline"),
|
|
ErrorReply::Disabled => write!(f, "disabled"),
|
|
ErrorReply::Other => write!(f, "other"),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum DisconnectReason {
|
|
TokenExpired,
|
|
}
|
|
|
|
impl<T, R> PhoenixMessage<T, R> {
|
|
pub fn new_message(
|
|
topic: impl Into<String>,
|
|
payload: T,
|
|
reference: Option<OutboundRequestId>,
|
|
) -> Self {
|
|
Self {
|
|
topic: topic.into(),
|
|
payload: Payload::Message(payload),
|
|
reference,
|
|
}
|
|
}
|
|
|
|
pub fn new_ok_reply(
|
|
topic: impl Into<String>,
|
|
payload: R,
|
|
reference: Option<OutboundRequestId>,
|
|
) -> Self {
|
|
Self {
|
|
topic: topic.into(),
|
|
payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))),
|
|
reference,
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
fn new_err_reply(
|
|
topic: impl Into<String>,
|
|
reason: ErrorReply,
|
|
reference: Option<OutboundRequestId>,
|
|
) -> Self {
|
|
Self {
|
|
topic: topic.into(),
|
|
payload: Payload::Reply(Reply::Error { reason }),
|
|
reference,
|
|
}
|
|
}
|
|
}
|
|
|
|
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
|
|
fn make_request(url: Secret<LoginUrl>, user_agent: String) -> Request {
|
|
use secrecy::ExposeSecret as _;
|
|
|
|
let mut r = [0u8; 16];
|
|
OsRng.fill_bytes(&mut r);
|
|
let key = base64::engine::general_purpose::STANDARD.encode(r);
|
|
|
|
Request::builder()
|
|
.method("GET")
|
|
.header("Host", url.expose_secret().host())
|
|
.header("Connection", "Upgrade")
|
|
.header("Upgrade", "websocket")
|
|
.header("Sec-WebSocket-Version", "13")
|
|
.header("Sec-WebSocket-Key", key)
|
|
.header("User-Agent", user_agent)
|
|
.uri(url.expose_secret().inner().as_str())
|
|
.body(())
|
|
.expect("building static request always works")
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
|
enum EgressControlMessage<T> {
|
|
PhxJoin(T),
|
|
Heartbeat(Empty),
|
|
}
|
|
|
|
fn serialize_msg(
|
|
topic: impl Into<String>,
|
|
payload: impl Serialize,
|
|
request_id: OutboundRequestId,
|
|
) -> String {
|
|
serde_json::to_string(&PhoenixMessage::<_, ()>::new_message(
|
|
topic,
|
|
payload,
|
|
Some(request_id),
|
|
))
|
|
.expect("we should always be able to serialize a join topic message")
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Deserialize, PartialEq, Debug)]
    #[serde(rename_all = "snake_case", tag = "event", content = "payload")] // This line makes it all work.
    enum Msg {
        Shout { hello: String },
    }

    /// Deserializes only the `event` / `payload` part of a raw message.
    fn parse_payload(json: &str) -> Payload<(), ()> {
        serde_json::from_str(json).unwrap()
    }

    /// The payload we expect for an error reply with the given reason.
    fn error_reply(reason: ErrorReply) -> Payload<(), ()> {
        Payload::Reply(Reply::Error { reason })
    }

    #[test]
    fn can_deserialize_inbound_message() {
        let msg = r#"{
            "topic": "room:lobby",
            "ref": null,
            "payload": {
                "hello": "world"
            },
            "join_ref": null,
            "event": "shout"
        }"#;

        let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();

        assert_eq!(msg.topic, "room:lobby");
        assert_eq!(msg.reference, None);
        assert_eq!(
            msg.payload,
            Payload::Message(Msg::Shout {
                hello: "world".to_owned()
            })
        );
    }

    #[test]
    fn unmatched_topic_reply() {
        let actual_reply = parse_payload(
            r#"
            {
                "event": "phx_reply",
                "ref": "12",
                "topic": "client",
                "payload":{
                    "status": "error",
                    "response":{
                        "reason": "unmatched topic"
                    }
                }
            }
            "#,
        );

        assert_eq!(actual_reply, error_reply(ErrorReply::UnmatchedTopic));
    }

    #[test]
    fn phx_close() {
        let actual_reply = parse_payload(
            r#"
            {
                "event": "phx_close",
                "ref": null,
                "topic": "client",
                "payload": {}
            }
            "#,
        );

        assert_eq!(actual_reply, Payload::<(), ()>::Close(Empty {}));
    }

    #[test]
    fn token_expired() {
        let actual_reply = parse_payload(
            r#"
            {
                "event": "disconnect",
                "ref": null,
                "topic": "client",
                "payload": { "reason": "token_expired" }
            }
            "#,
        );
        let expected_reply = Payload::<(), ()>::Disconnect {
            reason: DisconnectReason::TokenExpired,
        };

        assert_eq!(actual_reply, expected_reply);
    }

    #[test]
    fn not_found() {
        let actual_reply = parse_payload(
            r#"
            {
                "event": "phx_reply",
                "ref": null,
                "topic": "client",
                "payload": {
                    "status": "error",
                    "response": {
                        "reason": "not_found"
                    }
                }
            }
            "#,
        );

        assert_eq!(actual_reply, error_reply(ErrorReply::NotFound));
    }

    #[test]
    fn unexpected_error_reply() {
        let actual_reply = parse_payload(
            r#"
            {
                "event": "phx_reply",
                "ref": "12",
                "topic": "client",
                "payload": {
                    "status": "error",
                    "response": {
                        "reason": "bad reply"
                    }
                }
            }
            "#,
        );

        assert_eq!(actual_reply, error_reply(ErrorReply::Other));
    }

    #[test]
    fn invalid_version_reply() {
        let actual_reply = parse_payload(
            r#"
            {
                "event": "phx_reply",
                "ref": "12",
                "topic": "client",
                "payload":{
                    "status": "error",
                    "response":{
                        "reason": "invalid_version"
                    }
                }
            }
            "#,
        );

        assert_eq!(actual_reply, error_reply(ErrorReply::InvalidVersion));
    }

    #[test]
    fn disabled_err_reply() {
        let json = r#"{"event":"phx_reply","ref":null,"topic":"client","payload":{"status":"error","response":{"reason": "disabled"}}}"#;

        let actual = serde_json::from_str::<PhoenixMessage<(), ()>>(json).unwrap();
        let expected = PhoenixMessage::new_err_reply("client", ErrorReply::Disabled, None);

        assert_eq!(actual, expected)
    }
}
|