diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index dcb6392f0..7653e7f78 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -3,7 +3,7 @@ use crate::messages::{ EgressMessages, IngressMessages, RejectAccess, RequestConnection, }; use crate::CallbackHandler; -use anyhow::{anyhow, bail, Result}; +use anyhow::{bail, Result}; use boringtun::x25519::PublicKey; use connlib_shared::{ messages::{GatewayResponse, ResourceAccepted, ResourceDescription}, @@ -222,9 +222,6 @@ impl Eventloop { // TODO: Handle `init` message during operation. continue; } - Poll::Ready(phoenix_channel::Event::Disconnect(reason)) => { - return Poll::Ready(Err(anyhow!("Disconnected by portal: {reason}"))); - } _ => {} } diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 910c01e90..83613c659 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -14,6 +14,7 @@ use secrecy::{CloneableSecret, ExposeSecret as _, Secret}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::task::{ready, Context, Poll}; use tokio::net::TcpStream; +use tokio_tungstenite::tungstenite::http::StatusCode; use tokio_tungstenite::{ connect_async, tungstenite::{handshake::client::Request, Message}, @@ -45,7 +46,9 @@ pub struct PhoenixChannel { enum State { Connected(WebSocketStream>), - Connecting(BoxFuture<'static, Result>, Error>>), + Connecting( + BoxFuture<'static, Result>, InternalError>>, + ), } /// Creates a new [PhoenixChannel] to the given endpoint and waits for an `init` message. @@ -113,21 +116,53 @@ pub struct UnexpectedEventDuringInit(String); #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("provided URI is missing a host")] - MissingHost, - #[error("websocket failed")] - WebSocket(#[from] tokio_tungstenite::tungstenite::Error), - #[error("failed to serialize message")] - Serde(#[from] serde_json::Error), - #[error("server sent a reply without a reference")] - MissingReplyId, - #[error("server did not reply to our heartbeat")] + #[error("client error: {0}")] + ClientError(StatusCode), + #[error("token expired")] + TokenExpired, + #[error("max retries reached")] + MaxRetriesReached, +} + +impl Error { + pub fn is_authentication_error(&self) -> bool { + match self { + Error::ClientError(s) => s == &StatusCode::UNAUTHORIZED || s == &StatusCode::FORBIDDEN, + Error::TokenExpired => true, + Error::MaxRetriesReached => false, + } + } +} + +enum InternalError { + WebSocket(tokio_tungstenite::tungstenite::Error), + Serde(serde_json::Error), MissedHeartbeat, - #[error("connection close message")] CloseMessage, } -#[derive(Debug, PartialEq, Eq, Hash)] +impl fmt::Display for InternalError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(http)) => { + let status = http.status(); + let body = http + .body() + .as_deref() + .map(String::from_utf8_lossy) + .unwrap_or_default(); + + write!(f, "http error: {status} - {body}") + } + InternalError::WebSocket(e) => write!(f, "websocket connection failed: {e}"), + InternalError::Serde(e) => write!(f, "failed to deserialize message: {e}"), + InternalError::MissedHeartbeat => write!(f, "portal did not respond to our heartbeat"), + InternalError::CloseMessage => write!(f, "portal closed the websocket connection"), + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Deserialize, Serialize)] pub struct OutboundRequestId(u64); impl OutboundRequestId { @@ -135,6 +170,13 @@ impl OutboundRequestId { pub(crate) fn new(id: u64) -> Self { Self(id) } + + /// Internal function to make a copy. + /// + /// Not exposed publicly because these IDs are meant to be unique. + pub(crate) fn copy(&self) -> Self { + Self(self.0) + } } impl fmt::Display for OutboundRequestId { @@ -168,6 +210,10 @@ impl SecureUrl { pub fn host(&self) -> Option<&str> { self.inner.host_str() } + + pub fn port(&self) -> Option { + self.inner.port() + } } impl CloneableSecret for SecureUrl {} @@ -203,7 +249,9 @@ where secret_url: secret_url.clone(), user_agent: user_agent.clone(), state: State::Connecting(Box::pin(async move { - let (stream, _) = connect_async(make_request(secret_url, user_agent)?).await?; + let (stream, _) = connect_async(make_request(secret_url, user_agent)) + .await + .map_err(InternalError::WebSocket)?; Ok(stream) })), @@ -255,41 +303,28 @@ where continue; } + Err(InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http( + r, + ))) if r.status().is_client_error() => { + return Poll::Ready(Err(Error::ClientError(r.status()))); + } Err(e) => { - if let Error::WebSocket(tokio_tungstenite::tungstenite::Error::Http(r)) = &e - { - let status = r.status(); - - if status.is_client_error() { - let body = r - .body() - .as_deref() - .map(String::from_utf8_lossy) - .unwrap_or_default(); - - tracing::warn!( - "Fatal client error ({status}) in portal connection: {body}" - ); - - return Poll::Ready(Err(e)); - } - }; - let Some(backoff) = self.reconnect_backoff.next_backoff() else { tracing::warn!("Reconnect backoff expired"); - return Poll::Ready(Err(e)); + return Poll::Ready(Err(Error::MaxRetriesReached)); }; let secret_url = self.secret_url.clone(); let user_agent = self.user_agent.clone(); - tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {:#}", anyhow::Error::from(e)); + tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}"); self.state = State::Connecting(Box::pin(async move { tokio::time::sleep(backoff).await; - let (stream, _) = - connect_async(make_request(secret_url, user_agent)?).await?; + let (stream, _) = connect_async(make_request(secret_url, user_agent)) + .await + .map_err(InternalError::WebSocket)?; Ok(stream) })); @@ -306,7 +341,7 @@ where match stream.start_send_unpin(Message::Text(message)) { Ok(()) => {} Err(e) => { - self.reconnect_on_transient_error(Error::WebSocket(e)); + self.reconnect_on_transient_error(InternalError::WebSocket(e)); } } continue; @@ -329,7 +364,7 @@ where { Ok(m) => m, Err(e) if e.is_io() || e.is_eof() => { - self.reconnect_on_transient_error(Error::Serde(e)); + self.reconnect_on_transient_error(InternalError::Serde(e)); continue; } Err(e) => { @@ -338,26 +373,25 @@ where } }; - match message.payload { - Payload::Message(msg) => { + match (message.payload, message.reference) { + (Payload::Message(msg), _) => { return Poll::Ready(Ok(Event::InboundMessage { topic: message.topic, msg, })) } - Payload::Reply(Reply::Error { reason }) => { + (Payload::Reply(_), None) => { + tracing::warn!("Discarding reply because server omitted reference"); + continue; + } + (Payload::Reply(Reply::Error { reason }), Some(req_id)) => { return Poll::Ready(Ok(Event::ErrorResponse { topic: message.topic, - req_id: OutboundRequestId( - message.reference.ok_or(Error::MissingReplyId)?, - ), + req_id, reason, })); } - Payload::Reply(Reply::Ok(OkReply::Message(reply))) => { - let req_id = - OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?); - + (Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => { if self.pending_join_requests.remove(&req_id) { tracing::info!("Joined {} room on portal", message.topic); @@ -373,41 +407,39 @@ where res: reply, })); } - Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))) => { - let id = - OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?); - - if self.heartbeat.maybe_handle_reply(id) { + (Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => { + if self.heartbeat.maybe_handle_reply(req_id.copy()) { continue; } - tracing::trace!( - "Received empty reply for request {:?}", - message.reference - ); + tracing::trace!("Received empty reply for request {req_id:?}"); continue; } - Payload::Error(Empty {}) => { - return Poll::Ready(Ok(Event::ErrorResponse { - topic: message.topic, - req_id: OutboundRequestId( - message.reference.ok_or(Error::MissingReplyId)?, - ), - reason: ErrorReply::Other, - })) - } - Payload::Close(Empty {}) => { - self.reconnect_on_transient_error(Error::CloseMessage); + (Payload::Error(Empty {}), reference) => { + tracing::debug!( + ?reference, + topic = &message.topic, + "Received empty error response" + ); continue; } - Payload::Disconnect { reason } => { - return Poll::Ready(Ok(Event::Disconnect(reason))); + (Payload::Close(Empty {}), _) => { + self.reconnect_on_transient_error(InternalError::CloseMessage); + continue; + } + ( + Payload::Disconnect { + reason: DisconnectReason::TokenExpired, + }, + _, + ) => { + return Poll::Ready(Err(Error::TokenExpired)); } } } Poll::Ready(Some(Err(e))) => { - self.reconnect_on_transient_error(Error::WebSocket(e)); + self.reconnect_on_transient_error(InternalError::WebSocket(e)); continue; } _ => (), @@ -422,7 +454,7 @@ where return Poll::Ready(Ok(Event::HeartbeatSent)); } Poll::Ready(Err(MissedLastHeartbeat {})) => { - self.reconnect_on_transient_error(Error::MissedHeartbeat); + self.reconnect_on_transient_error(InternalError::MissedHeartbeat); continue; } _ => (), @@ -434,7 +466,7 @@ where tracing::trace!("Flushed websocket"); } Poll::Ready(Err(e)) => { - self.reconnect_on_transient_error(Error::WebSocket(e)); + self.reconnect_on_transient_error(InternalError::WebSocket(e)); continue; } Poll::Pending => {} @@ -447,7 +479,7 @@ where /// Sets the channels state to [`State::Connecting`] with the given error. /// /// The [`PhoenixChannel::poll`] function will handle the reconnect if appropriate for the given error. - fn reconnect_on_transient_error(&mut self, e: Error) { + fn reconnect_on_transient_error(&mut self, e: InternalError) { self.state = State::Connecting(future::ready(Err(e)).boxed()) } @@ -459,19 +491,23 @@ where let request_id = self.fetch_add_request_id(); // We don't care about the reply type when serializing - let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, request_id)) - .expect("we should always be able to serialize a join topic message"); + let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new( + topic, + payload, + request_id.copy(), + )) + .expect("we should always be able to serialize a join topic message"); self.pending_messages.push_back(msg); - OutboundRequestId(request_id) + request_id } - fn fetch_add_request_id(&mut self) -> u64 { + fn fetch_add_request_id(&mut self) -> OutboundRequestId { let next_id = self.next_request_id; self.next_request_id += 1; - next_id + OutboundRequestId(next_id) } /// Cast this instance of [PhoenixChannel] to new message types. @@ -516,17 +552,16 @@ pub enum Event { topic: String, msg: TInboundMsg, }, - Disconnect(String), } -#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] pub struct PhoenixMessage { // TODO: we should use a newtype pattern for topics topic: String, #[serde(flatten)] payload: Payload, #[serde(rename = "ref")] - reference: Option, + reference: Option, } #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] @@ -539,7 +574,7 @@ enum Payload { #[serde(rename = "phx_close")] Close(Empty), #[serde(rename = "disconnect")] - Disconnect { reason: String }, + Disconnect { reason: DisconnectReason }, #[serde(untagged)] Message(T), } @@ -569,7 +604,6 @@ enum OkReply { pub enum ErrorReply { #[serde(rename = "unmatched topic")] UnmatchedTopic, - TokenExpired, NotFound, Offline, Disabled, @@ -577,8 +611,14 @@ pub enum ErrorReply { Other, } +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum DisconnectReason { + TokenExpired, +} + impl PhoenixMessage { - pub fn new(topic: impl Into, payload: T, reference: u64) -> Self { + pub fn new(topic: impl Into, payload: T, reference: OutboundRequestId) -> Self { Self { topic: topic.into(), payload: Payload::Message(payload), @@ -588,37 +628,35 @@ impl PhoenixMessage { } // This is basically the same as tungstenite does but we add some new headers (namely user-agent) -fn make_request(secret_url: Secret, user_agent: String) -> Result { +fn make_request(secret_url: Secret, user_agent: String) -> Request { use secrecy::ExposeSecret; - let host = secret_url - .expose_secret() - .inner - .host() - .ok_or(Error::MissingHost)?; - let host = if let Some(port) = secret_url.expose_secret().inner.port() { - format!("{host}:{port}") - } else { - host.to_string() - }; - let mut r = [0u8; 16]; OsRng.fill_bytes(&mut r); let key = base64::engine::general_purpose::STANDARD.encode(r); - let req = Request::builder() + let mut req_builder = Request::builder() .method("GET") - .header("Host", host) .header("Connection", "Upgrade") .header("Upgrade", "websocket") .header("Sec-WebSocket-Version", "13") .header("Sec-WebSocket-Key", key) .header("User-Agent", user_agent) - .uri(secret_url.expose_secret().inner.as_str()) - .body(()) - .expect("building static request always works"); + .uri(secret_url.expose_secret().inner.as_str()); - Ok(req) + if let Some(host) = secret_url.expose_secret().host() { + let host = secret_url + .expose_secret() + .port() + .map(|port| format!("{host}:{port}")) + .unwrap_or(host.to_string()); + + req_builder = req_builder.header("Host", host); + } + + req_builder + .body(()) + .expect("building static request always works") } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -727,7 +765,7 @@ mod tests { "#; let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); let expected_reply = Payload::<(), ()>::Disconnect { - reason: "token_expired".to_string(), + reason: DisconnectReason::TokenExpired, }; assert_eq!(actual_reply, expected_reply); } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 3193f33b5..52e0392cf 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -9,7 +9,7 @@ use futures::channel::mpsc; use futures::{future, FutureExt, SinkExt, StreamExt}; use opentelemetry::{sdk, KeyValue}; use opentelemetry_otlp::WithExportConfig; -use phoenix_channel::{Error, Event, PhoenixChannel, SecureUrl}; +use phoenix_channel::{Event, PhoenixChannel, SecureUrl}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use secrecy::{Secret, SecretString}; @@ -501,13 +501,6 @@ where // Priority 5: Handle portal messages match self.channel.as_mut().map(|c| c.poll(cx)) { - Some(Poll::Ready(Ok(Event::Disconnect(reason)))) => { - return Poll::Ready(Err(anyhow!("Connection closed by portal: {reason}"))); - } - Some(Poll::Ready(Err(Error::Serde(e)))) => { - tracing::warn!(target: "relay", "Failed to deserialize portal message: {e}"); - continue; // This is not a hard-error, we can continue. - } Some(Poll::Ready(Err(e))) => { return Poll::Ready(Err(anyhow!("Portal connection failed: {e}"))); }