From 4339030d0370050e20272f6a5bdf4da54d475d54 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sat, 9 Mar 2024 19:03:25 +1100 Subject: [PATCH] refactor(phoenix-channel): reduce `Error` to fatal errors (#4015) As part of doing https://github.com/firezone/firezone/pull/3682, we noticed that the handling of errors up to the clients needs to differentiate between fatal errors that require clearing the token vs not. Upon closer inspection of `phoenix_channel::Error`, it becomes obvious that the current design is not good here. In particular, we handle certain errors with retries internally but still expose those same errors. To make this more obvious, we reduce the public `Error` to the variants that are actually fatal. Those can really only be three: - HTTP client errors (those are by definition non-retryable) - Token expired - We have reached our max number of retries --- rust/gateway/src/eventloop.rs | 5 +- rust/phoenix-channel/src/lib.rs | 246 ++++++++++++++++++-------------- rust/relay/src/main.rs | 9 +- 3 files changed, 144 insertions(+), 116 deletions(-) diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index dcb6392f0..7653e7f78 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -3,7 +3,7 @@ use crate::messages::{ EgressMessages, IngressMessages, RejectAccess, RequestConnection, }; use crate::CallbackHandler; -use anyhow::{anyhow, bail, Result}; +use anyhow::{bail, Result}; use boringtun::x25519::PublicKey; use connlib_shared::{ messages::{GatewayResponse, ResourceAccepted, ResourceDescription}, @@ -222,9 +222,6 @@ impl Eventloop { // TODO: Handle `init` message during operation. continue; } - Poll::Ready(phoenix_channel::Event::Disconnect(reason)) => { - return Poll::Ready(Err(anyhow!("Disconnected by portal: {reason}"))); - } _ => {} } diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 910c01e90..83613c659 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -14,6 +14,7 @@ use secrecy::{CloneableSecret, ExposeSecret as _, Secret}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::task::{ready, Context, Poll}; use tokio::net::TcpStream; +use tokio_tungstenite::tungstenite::http::StatusCode; use tokio_tungstenite::{ connect_async, tungstenite::{handshake::client::Request, Message}, @@ -45,7 +46,9 @@ pub struct PhoenixChannel { enum State { Connected(WebSocketStream>), - Connecting(BoxFuture<'static, Result>, Error>>), + Connecting( + BoxFuture<'static, Result>, InternalError>>, + ), } /// Creates a new [PhoenixChannel] to the given endpoint and waits for an `init` message. @@ -113,21 +116,53 @@ pub struct UnexpectedEventDuringInit(String); #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("provided URI is missing a host")] - MissingHost, - #[error("websocket failed")] - WebSocket(#[from] tokio_tungstenite::tungstenite::Error), - #[error("failed to serialize message")] - Serde(#[from] serde_json::Error), - #[error("server sent a reply without a reference")] - MissingReplyId, - #[error("server did not reply to our heartbeat")] + #[error("client error: {0}")] + ClientError(StatusCode), + #[error("token expired")] + TokenExpired, + #[error("max retries reached")] + MaxRetriesReached, +} + +impl Error { + pub fn is_authentication_error(&self) -> bool { + match self { + Error::ClientError(s) => s == &StatusCode::UNAUTHORIZED || s == &StatusCode::FORBIDDEN, + Error::TokenExpired => true, + Error::MaxRetriesReached => false, + } + } +} + +enum InternalError { + WebSocket(tokio_tungstenite::tungstenite::Error), + Serde(serde_json::Error), MissedHeartbeat, - #[error("connection close message")] CloseMessage, } -#[derive(Debug, PartialEq, Eq, Hash)] +impl fmt::Display for InternalError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(http)) => { + let status = http.status(); + let body = http + .body() + .as_deref() + .map(String::from_utf8_lossy) + .unwrap_or_default(); + + write!(f, "http error: {status} - {body}") + } + InternalError::WebSocket(e) => write!(f, "websocket connection failed: {e}"), + InternalError::Serde(e) => write!(f, "failed to deserialize message: {e}"), + InternalError::MissedHeartbeat => write!(f, "portal did not respond to our heartbeat"), + InternalError::CloseMessage => write!(f, "portal closed the websocket connection"), + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Deserialize, Serialize)] pub struct OutboundRequestId(u64); impl OutboundRequestId { @@ -135,6 +170,13 @@ impl OutboundRequestId { pub(crate) fn new(id: u64) -> Self { Self(id) } + + /// Internal function to make a copy. + /// + /// Not exposed publicly because these IDs are meant to be unique. + pub(crate) fn copy(&self) -> Self { + Self(self.0) + } } impl fmt::Display for OutboundRequestId { @@ -168,6 +210,10 @@ impl SecureUrl { pub fn host(&self) -> Option<&str> { self.inner.host_str() } + + pub fn port(&self) -> Option { + self.inner.port() + } } impl CloneableSecret for SecureUrl {} @@ -203,7 +249,9 @@ where secret_url: secret_url.clone(), user_agent: user_agent.clone(), state: State::Connecting(Box::pin(async move { - let (stream, _) = connect_async(make_request(secret_url, user_agent)?).await?; + let (stream, _) = connect_async(make_request(secret_url, user_agent)) + .await + .map_err(InternalError::WebSocket)?; Ok(stream) })), @@ -255,41 +303,28 @@ where continue; } + Err(InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http( + r, + ))) if r.status().is_client_error() => { + return Poll::Ready(Err(Error::ClientError(r.status()))); + } Err(e) => { - if let Error::WebSocket(tokio_tungstenite::tungstenite::Error::Http(r)) = &e - { - let status = r.status(); - - if status.is_client_error() { - let body = r - .body() - .as_deref() - .map(String::from_utf8_lossy) - .unwrap_or_default(); - - tracing::warn!( - "Fatal client error ({status}) in portal connection: {body}" - ); - - return Poll::Ready(Err(e)); - } - }; - let Some(backoff) = self.reconnect_backoff.next_backoff() else { tracing::warn!("Reconnect backoff expired"); - return Poll::Ready(Err(e)); + return Poll::Ready(Err(Error::MaxRetriesReached)); }; let secret_url = self.secret_url.clone(); let user_agent = self.user_agent.clone(); - tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {:#}", anyhow::Error::from(e)); + tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}"); self.state = State::Connecting(Box::pin(async move { tokio::time::sleep(backoff).await; - let (stream, _) = - connect_async(make_request(secret_url, user_agent)?).await?; + let (stream, _) = connect_async(make_request(secret_url, user_agent)) + .await + .map_err(InternalError::WebSocket)?; Ok(stream) })); @@ -306,7 +341,7 @@ where match stream.start_send_unpin(Message::Text(message)) { Ok(()) => {} Err(e) => { - self.reconnect_on_transient_error(Error::WebSocket(e)); + self.reconnect_on_transient_error(InternalError::WebSocket(e)); } } continue; @@ -329,7 +364,7 @@ where { Ok(m) => m, Err(e) if e.is_io() || e.is_eof() => { - self.reconnect_on_transient_error(Error::Serde(e)); + self.reconnect_on_transient_error(InternalError::Serde(e)); continue; } Err(e) => { @@ -338,26 +373,25 @@ where } }; - match message.payload { - Payload::Message(msg) => { + match (message.payload, message.reference) { + (Payload::Message(msg), _) => { return Poll::Ready(Ok(Event::InboundMessage { topic: message.topic, msg, })) } - Payload::Reply(Reply::Error { reason }) => { + (Payload::Reply(_), None) => { + tracing::warn!("Discarding reply because server omitted reference"); + continue; + } + (Payload::Reply(Reply::Error { reason }), Some(req_id)) => { return Poll::Ready(Ok(Event::ErrorResponse { topic: message.topic, - req_id: OutboundRequestId( - message.reference.ok_or(Error::MissingReplyId)?, - ), + req_id, reason, })); } - Payload::Reply(Reply::Ok(OkReply::Message(reply))) => { - let req_id = - OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?); - + (Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => { if self.pending_join_requests.remove(&req_id) { tracing::info!("Joined {} room on portal", message.topic); @@ -373,41 +407,39 @@ where res: reply, })); } - Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))) => { - let id = - OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?); - - if self.heartbeat.maybe_handle_reply(id) { + (Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => { + if self.heartbeat.maybe_handle_reply(req_id.copy()) { continue; } - tracing::trace!( - "Received empty reply for request {:?}", - message.reference - ); + tracing::trace!("Received empty reply for request {req_id:?}"); continue; } - Payload::Error(Empty {}) => { - return Poll::Ready(Ok(Event::ErrorResponse { - topic: message.topic, - req_id: OutboundRequestId( - message.reference.ok_or(Error::MissingReplyId)?, - ), - reason: ErrorReply::Other, - })) - } - Payload::Close(Empty {}) => { - self.reconnect_on_transient_error(Error::CloseMessage); + (Payload::Error(Empty {}), reference) => { + tracing::debug!( + ?reference, + topic = &message.topic, + "Received empty error response" + ); continue; } - Payload::Disconnect { reason } => { - return Poll::Ready(Ok(Event::Disconnect(reason))); + (Payload::Close(Empty {}), _) => { + self.reconnect_on_transient_error(InternalError::CloseMessage); + continue; + } + ( + Payload::Disconnect { + reason: DisconnectReason::TokenExpired, + }, + _, + ) => { + return Poll::Ready(Err(Error::TokenExpired)); } } } Poll::Ready(Some(Err(e))) => { - self.reconnect_on_transient_error(Error::WebSocket(e)); + self.reconnect_on_transient_error(InternalError::WebSocket(e)); continue; } _ => (), @@ -422,7 +454,7 @@ where return Poll::Ready(Ok(Event::HeartbeatSent)); } Poll::Ready(Err(MissedLastHeartbeat {})) => { - self.reconnect_on_transient_error(Error::MissedHeartbeat); + self.reconnect_on_transient_error(InternalError::MissedHeartbeat); continue; } _ => (), @@ -434,7 +466,7 @@ where tracing::trace!("Flushed websocket"); } Poll::Ready(Err(e)) => { - self.reconnect_on_transient_error(Error::WebSocket(e)); + self.reconnect_on_transient_error(InternalError::WebSocket(e)); continue; } Poll::Pending => {} @@ -447,7 +479,7 @@ where /// Sets the channels state to [`State::Connecting`] with the given error. /// /// The [`PhoenixChannel::poll`] function will handle the reconnect if appropriate for the given error. - fn reconnect_on_transient_error(&mut self, e: Error) { + fn reconnect_on_transient_error(&mut self, e: InternalError) { self.state = State::Connecting(future::ready(Err(e)).boxed()) } @@ -459,19 +491,23 @@ where let request_id = self.fetch_add_request_id(); // We don't care about the reply type when serializing - let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, request_id)) - .expect("we should always be able to serialize a join topic message"); + let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new( + topic, + payload, + request_id.copy(), + )) + .expect("we should always be able to serialize a join topic message"); self.pending_messages.push_back(msg); - OutboundRequestId(request_id) + request_id } - fn fetch_add_request_id(&mut self) -> u64 { + fn fetch_add_request_id(&mut self) -> OutboundRequestId { let next_id = self.next_request_id; self.next_request_id += 1; - next_id + OutboundRequestId(next_id) } /// Cast this instance of [PhoenixChannel] to new message types. @@ -516,17 +552,16 @@ pub enum Event { topic: String, msg: TInboundMsg, }, - Disconnect(String), } -#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] pub struct PhoenixMessage { // TODO: we should use a newtype pattern for topics topic: String, #[serde(flatten)] payload: Payload, #[serde(rename = "ref")] - reference: Option, + reference: Option, } #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] @@ -539,7 +574,7 @@ enum Payload { #[serde(rename = "phx_close")] Close(Empty), #[serde(rename = "disconnect")] - Disconnect { reason: String }, + Disconnect { reason: DisconnectReason }, #[serde(untagged)] Message(T), } @@ -569,7 +604,6 @@ enum OkReply { pub enum ErrorReply { #[serde(rename = "unmatched topic")] UnmatchedTopic, - TokenExpired, NotFound, Offline, Disabled, @@ -577,8 +611,14 @@ pub enum ErrorReply { Other, } +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum DisconnectReason { + TokenExpired, +} + impl PhoenixMessage { - pub fn new(topic: impl Into, payload: T, reference: u64) -> Self { + pub fn new(topic: impl Into, payload: T, reference: OutboundRequestId) -> Self { Self { topic: topic.into(), payload: Payload::Message(payload), @@ -588,37 +628,35 @@ impl PhoenixMessage { } // This is basically the same as tungstenite does but we add some new headers (namely user-agent) -fn make_request(secret_url: Secret, user_agent: String) -> Result { +fn make_request(secret_url: Secret, user_agent: String) -> Request { use secrecy::ExposeSecret; - let host = secret_url - .expose_secret() - .inner - .host() - .ok_or(Error::MissingHost)?; - let host = if let Some(port) = secret_url.expose_secret().inner.port() { - format!("{host}:{port}") - } else { - host.to_string() - }; - let mut r = [0u8; 16]; OsRng.fill_bytes(&mut r); let key = base64::engine::general_purpose::STANDARD.encode(r); - let req = Request::builder() + let mut req_builder = Request::builder() .method("GET") - .header("Host", host) .header("Connection", "Upgrade") .header("Upgrade", "websocket") .header("Sec-WebSocket-Version", "13") .header("Sec-WebSocket-Key", key) .header("User-Agent", user_agent) - .uri(secret_url.expose_secret().inner.as_str()) - .body(()) - .expect("building static request always works"); + .uri(secret_url.expose_secret().inner.as_str()); - Ok(req) + if let Some(host) = secret_url.expose_secret().host() { + let host = secret_url + .expose_secret() + .port() + .map(|port| format!("{host}:{port}")) + .unwrap_or(host.to_string()); + + req_builder = req_builder.header("Host", host); + } + + req_builder + .body(()) + .expect("building static request always works") } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -727,7 +765,7 @@ mod tests { "#; let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); let expected_reply = Payload::<(), ()>::Disconnect { - reason: "token_expired".to_string(), + reason: DisconnectReason::TokenExpired, }; assert_eq!(actual_reply, expected_reply); } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 3193f33b5..52e0392cf 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -9,7 +9,7 @@ use futures::channel::mpsc; use futures::{future, FutureExt, SinkExt, StreamExt}; use opentelemetry::{sdk, KeyValue}; use opentelemetry_otlp::WithExportConfig; -use phoenix_channel::{Error, Event, PhoenixChannel, SecureUrl}; +use phoenix_channel::{Event, PhoenixChannel, SecureUrl}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use secrecy::{Secret, SecretString}; @@ -501,13 +501,6 @@ where // Priority 5: Handle portal messages match self.channel.as_mut().map(|c| c.poll(cx)) { - Some(Poll::Ready(Ok(Event::Disconnect(reason)))) => { - return Poll::Ready(Err(anyhow!("Connection closed by portal: {reason}"))); - } - Some(Poll::Ready(Err(Error::Serde(e)))) => { - tracing::warn!(target: "relay", "Failed to deserialize portal message: {e}"); - continue; // This is not a hard-error, we can continue. - } Some(Poll::Ready(Err(e))) => { return Poll::Ready(Err(anyhow!("Portal connection failed: {e}"))); }