refactor(phoenix-channel): reduce Error to fatal errors (#4015)

While working on https://github.com/firezone/firezone/pull/3682, we
noticed that the errors propagated up to the clients need to
differentiate between fatal errors that require clearing the token and
fatal errors that do not.

Upon closer inspection of `phoenix_channel::Error`, it becomes obvious
that the current design does not allow for this. In particular, we
handle certain errors with retries internally but still expose those
same errors publicly.

To make this more obvious, we reduce the public `Error` to the variants
that are actually fatal. Those can really only be three:

- HTTP client errors (those are by definition non-retryable)
- Token expired
- We have reached our max number of retries
This commit is contained in:
Thomas Eizinger
2024-03-09 19:03:25 +11:00
committed by GitHub
parent 21fe85048c
commit 4339030d03
3 changed files with 144 additions and 116 deletions

View File

@@ -3,7 +3,7 @@ use crate::messages::{
EgressMessages, IngressMessages, RejectAccess, RequestConnection,
};
use crate::CallbackHandler;
use anyhow::{anyhow, bail, Result};
use anyhow::{bail, Result};
use boringtun::x25519::PublicKey;
use connlib_shared::{
messages::{GatewayResponse, ResourceAccepted, ResourceDescription},
@@ -222,9 +222,6 @@ impl Eventloop {
// TODO: Handle `init` message during operation.
continue;
}
Poll::Ready(phoenix_channel::Event::Disconnect(reason)) => {
return Poll::Ready(Err(anyhow!("Disconnected by portal: {reason}")));
}
_ => {}
}

View File

@@ -14,6 +14,7 @@ use secrecy::{CloneableSecret, ExposeSecret as _, Secret};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::task::{ready, Context, Poll};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::http::StatusCode;
use tokio_tungstenite::{
connect_async,
tungstenite::{handshake::client::Request, Message},
@@ -45,7 +46,9 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
enum State {
Connected(WebSocketStream<MaybeTlsStream<TcpStream>>),
Connecting(BoxFuture<'static, Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Error>>),
Connecting(
BoxFuture<'static, Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError>>,
),
}
/// Creates a new [PhoenixChannel] to the given endpoint and waits for an `init` message.
@@ -113,21 +116,53 @@ pub struct UnexpectedEventDuringInit(String);
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("provided URI is missing a host")]
MissingHost,
#[error("websocket failed")]
WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
#[error("failed to serialize message")]
Serde(#[from] serde_json::Error),
#[error("server sent a reply without a reference")]
MissingReplyId,
#[error("server did not reply to our heartbeat")]
#[error("client error: {0}")]
ClientError(StatusCode),
#[error("token expired")]
TokenExpired,
#[error("max retries reached")]
MaxRetriesReached,
}
impl Error {
    /// Returns whether this error represents an authentication failure.
    ///
    /// This is `true` for [`Error::TokenExpired`] and for [`Error::ClientError`] carrying
    /// either `401 Unauthorized` or `403 Forbidden`. Every other error yields `false`.
    pub fn is_authentication_error(&self) -> bool {
        match self {
            Self::TokenExpired => true,
            Self::MaxRetriesReached => false,
            Self::ClientError(status) => {
                [StatusCode::UNAUTHORIZED, StatusCode::FORBIDDEN].contains(status)
            }
        }
    }
}
enum InternalError {
WebSocket(tokio_tungstenite::tungstenite::Error),
Serde(serde_json::Error),
MissedHeartbeat,
#[error("connection close message")]
CloseMessage,
}
#[derive(Debug, PartialEq, Eq, Hash)]
impl fmt::Display for InternalError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(http)) => {
let status = http.status();
let body = http
.body()
.as_deref()
.map(String::from_utf8_lossy)
.unwrap_or_default();
write!(f, "http error: {status} - {body}")
}
InternalError::WebSocket(e) => write!(f, "websocket connection failed: {e}"),
InternalError::Serde(e) => write!(f, "failed to deserialize message: {e}"),
InternalError::MissedHeartbeat => write!(f, "portal did not respond to our heartbeat"),
InternalError::CloseMessage => write!(f, "portal closed the websocket connection"),
}
}
}
#[derive(Debug, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub struct OutboundRequestId(u64);
impl OutboundRequestId {
@@ -135,6 +170,13 @@ impl OutboundRequestId {
pub(crate) fn new(id: u64) -> Self {
Self(id)
}
/// Internal function to make a copy.
///
/// Returns a new ID wrapping the same inner `u64` as `self`.
///
/// Not exposed publicly because these IDs are meant to be unique.
pub(crate) fn copy(&self) -> Self {
Self(self.0)
}
}
impl fmt::Display for OutboundRequestId {
@@ -168,6 +210,10 @@ impl SecureUrl {
pub fn host(&self) -> Option<&str> {
self.inner.host_str()
}
/// Returns the port of the wrapped URL, if any.
///
/// This forwards to `port()` on the inner URL value.
pub fn port(&self) -> Option<u16> {
self.inner.port()
}
}
impl CloneableSecret for SecureUrl {}
@@ -203,7 +249,9 @@ where
secret_url: secret_url.clone(),
user_agent: user_agent.clone(),
state: State::Connecting(Box::pin(async move {
let (stream, _) = connect_async(make_request(secret_url, user_agent)?).await?;
let (stream, _) = connect_async(make_request(secret_url, user_agent))
.await
.map_err(InternalError::WebSocket)?;
Ok(stream)
})),
@@ -255,41 +303,28 @@ where
continue;
}
Err(InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(
r,
))) if r.status().is_client_error() => {
return Poll::Ready(Err(Error::ClientError(r.status())));
}
Err(e) => {
if let Error::WebSocket(tokio_tungstenite::tungstenite::Error::Http(r)) = &e
{
let status = r.status();
if status.is_client_error() {
let body = r
.body()
.as_deref()
.map(String::from_utf8_lossy)
.unwrap_or_default();
tracing::warn!(
"Fatal client error ({status}) in portal connection: {body}"
);
return Poll::Ready(Err(e));
}
};
let Some(backoff) = self.reconnect_backoff.next_backoff() else {
tracing::warn!("Reconnect backoff expired");
return Poll::Ready(Err(e));
return Poll::Ready(Err(Error::MaxRetriesReached));
};
let secret_url = self.secret_url.clone();
let user_agent = self.user_agent.clone();
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {:#}", anyhow::Error::from(e));
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}");
self.state = State::Connecting(Box::pin(async move {
tokio::time::sleep(backoff).await;
let (stream, _) =
connect_async(make_request(secret_url, user_agent)?).await?;
let (stream, _) = connect_async(make_request(secret_url, user_agent))
.await
.map_err(InternalError::WebSocket)?;
Ok(stream)
}));
@@ -306,7 +341,7 @@ where
match stream.start_send_unpin(Message::Text(message)) {
Ok(()) => {}
Err(e) => {
self.reconnect_on_transient_error(Error::WebSocket(e));
self.reconnect_on_transient_error(InternalError::WebSocket(e));
}
}
continue;
@@ -329,7 +364,7 @@ where
{
Ok(m) => m,
Err(e) if e.is_io() || e.is_eof() => {
self.reconnect_on_transient_error(Error::Serde(e));
self.reconnect_on_transient_error(InternalError::Serde(e));
continue;
}
Err(e) => {
@@ -338,26 +373,25 @@ where
}
};
match message.payload {
Payload::Message(msg) => {
match (message.payload, message.reference) {
(Payload::Message(msg), _) => {
return Poll::Ready(Ok(Event::InboundMessage {
topic: message.topic,
msg,
}))
}
Payload::Reply(Reply::Error { reason }) => {
(Payload::Reply(_), None) => {
tracing::warn!("Discarding reply because server omitted reference");
continue;
}
(Payload::Reply(Reply::Error { reason }), Some(req_id)) => {
return Poll::Ready(Ok(Event::ErrorResponse {
topic: message.topic,
req_id: OutboundRequestId(
message.reference.ok_or(Error::MissingReplyId)?,
),
req_id,
reason,
}));
}
Payload::Reply(Reply::Ok(OkReply::Message(reply))) => {
let req_id =
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
(Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => {
if self.pending_join_requests.remove(&req_id) {
tracing::info!("Joined {} room on portal", message.topic);
@@ -373,41 +407,39 @@ where
res: reply,
}));
}
Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))) => {
let id =
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
if self.heartbeat.maybe_handle_reply(id) {
(Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => {
if self.heartbeat.maybe_handle_reply(req_id.copy()) {
continue;
}
tracing::trace!(
"Received empty reply for request {:?}",
message.reference
);
tracing::trace!("Received empty reply for request {req_id:?}");
continue;
}
Payload::Error(Empty {}) => {
return Poll::Ready(Ok(Event::ErrorResponse {
topic: message.topic,
req_id: OutboundRequestId(
message.reference.ok_or(Error::MissingReplyId)?,
),
reason: ErrorReply::Other,
}))
}
Payload::Close(Empty {}) => {
self.reconnect_on_transient_error(Error::CloseMessage);
(Payload::Error(Empty {}), reference) => {
tracing::debug!(
?reference,
topic = &message.topic,
"Received empty error response"
);
continue;
}
Payload::Disconnect { reason } => {
return Poll::Ready(Ok(Event::Disconnect(reason)));
(Payload::Close(Empty {}), _) => {
self.reconnect_on_transient_error(InternalError::CloseMessage);
continue;
}
(
Payload::Disconnect {
reason: DisconnectReason::TokenExpired,
},
_,
) => {
return Poll::Ready(Err(Error::TokenExpired));
}
}
}
Poll::Ready(Some(Err(e))) => {
self.reconnect_on_transient_error(Error::WebSocket(e));
self.reconnect_on_transient_error(InternalError::WebSocket(e));
continue;
}
_ => (),
@@ -422,7 +454,7 @@ where
return Poll::Ready(Ok(Event::HeartbeatSent));
}
Poll::Ready(Err(MissedLastHeartbeat {})) => {
self.reconnect_on_transient_error(Error::MissedHeartbeat);
self.reconnect_on_transient_error(InternalError::MissedHeartbeat);
continue;
}
_ => (),
@@ -434,7 +466,7 @@ where
tracing::trace!("Flushed websocket");
}
Poll::Ready(Err(e)) => {
self.reconnect_on_transient_error(Error::WebSocket(e));
self.reconnect_on_transient_error(InternalError::WebSocket(e));
continue;
}
Poll::Pending => {}
@@ -447,7 +479,7 @@ where
/// Sets the channels state to [`State::Connecting`] with the given error.
///
/// The [`PhoenixChannel::poll`] function will handle the reconnect if appropriate for the given error.
fn reconnect_on_transient_error(&mut self, e: Error) {
fn reconnect_on_transient_error(&mut self, e: InternalError) {
self.state = State::Connecting(future::ready(Err(e)).boxed())
}
@@ -459,19 +491,23 @@ where
let request_id = self.fetch_add_request_id();
// We don't care about the reply type when serializing
let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, request_id))
.expect("we should always be able to serialize a join topic message");
let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new(
topic,
payload,
request_id.copy(),
))
.expect("we should always be able to serialize a join topic message");
self.pending_messages.push_back(msg);
OutboundRequestId(request_id)
request_id
}
fn fetch_add_request_id(&mut self) -> u64 {
fn fetch_add_request_id(&mut self) -> OutboundRequestId {
let next_id = self.next_request_id;
self.next_request_id += 1;
next_id
OutboundRequestId(next_id)
}
/// Cast this instance of [PhoenixChannel] to new message types.
@@ -516,17 +552,16 @@ pub enum Event<TInboundMsg, TOutboundRes> {
topic: String,
msg: TInboundMsg,
},
Disconnect(String),
}
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct PhoenixMessage<T, R> {
// TODO: we should use a newtype pattern for topics
topic: String,
#[serde(flatten)]
payload: Payload<T, R>,
#[serde(rename = "ref")]
reference: Option<u64>,
reference: Option<OutboundRequestId>,
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
@@ -539,7 +574,7 @@ enum Payload<T, R> {
#[serde(rename = "phx_close")]
Close(Empty),
#[serde(rename = "disconnect")]
Disconnect { reason: String },
Disconnect { reason: DisconnectReason },
#[serde(untagged)]
Message(T),
}
@@ -569,7 +604,6 @@ enum OkReply<T> {
pub enum ErrorReply {
#[serde(rename = "unmatched topic")]
UnmatchedTopic,
TokenExpired,
NotFound,
Offline,
Disabled,
@@ -577,8 +611,14 @@ pub enum ErrorReply {
Other,
}
/// Reason carried by a `disconnect` payload.
///
/// Variants are (de)serialized using their snake_case names, e.g. `token_expired`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DisconnectReason {
/// The token has expired; receiving a disconnect with this reason is reported as `Error::TokenExpired`.
TokenExpired,
}
impl<T, R> PhoenixMessage<T, R> {
pub fn new(topic: impl Into<String>, payload: T, reference: u64) -> Self {
pub fn new(topic: impl Into<String>, payload: T, reference: OutboundRequestId) -> Self {
Self {
topic: topic.into(),
payload: Payload::Message(payload),
@@ -588,37 +628,35 @@ impl<T, R> PhoenixMessage<T, R> {
}
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
fn make_request(secret_url: Secret<SecureUrl>, user_agent: String) -> Result<Request, Error> {
fn make_request(secret_url: Secret<SecureUrl>, user_agent: String) -> Request {
use secrecy::ExposeSecret;
let host = secret_url
.expose_secret()
.inner
.host()
.ok_or(Error::MissingHost)?;
let host = if let Some(port) = secret_url.expose_secret().inner.port() {
format!("{host}:{port}")
} else {
host.to_string()
};
let mut r = [0u8; 16];
OsRng.fill_bytes(&mut r);
let key = base64::engine::general_purpose::STANDARD.encode(r);
let req = Request::builder()
let mut req_builder = Request::builder()
.method("GET")
.header("Host", host)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", key)
.header("User-Agent", user_agent)
.uri(secret_url.expose_secret().inner.as_str())
.body(())
.expect("building static request always works");
.uri(secret_url.expose_secret().inner.as_str());
Ok(req)
if let Some(host) = secret_url.expose_secret().host() {
let host = secret_url
.expose_secret()
.port()
.map(|port| format!("{host}:{port}"))
.unwrap_or(host.to_string());
req_builder = req_builder.header("Host", host);
}
req_builder
.body(())
.expect("building static request always works")
}
#[derive(Debug, Deserialize, Serialize, Clone)]
@@ -727,7 +765,7 @@ mod tests {
"#;
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
let expected_reply = Payload::<(), ()>::Disconnect {
reason: "token_expired".to_string(),
reason: DisconnectReason::TokenExpired,
};
assert_eq!(actual_reply, expected_reply);
}

View File

@@ -9,7 +9,7 @@ use futures::channel::mpsc;
use futures::{future, FutureExt, SinkExt, StreamExt};
use opentelemetry::{sdk, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use phoenix_channel::{Error, Event, PhoenixChannel, SecureUrl};
use phoenix_channel::{Event, PhoenixChannel, SecureUrl};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use secrecy::{Secret, SecretString};
@@ -501,13 +501,6 @@ where
// Priority 5: Handle portal messages
match self.channel.as_mut().map(|c| c.poll(cx)) {
Some(Poll::Ready(Ok(Event::Disconnect(reason)))) => {
return Poll::Ready(Err(anyhow!("Connection closed by portal: {reason}")));
}
Some(Poll::Ready(Err(Error::Serde(e)))) => {
tracing::warn!(target: "relay", "Failed to deserialize portal message: {e}");
continue; // This is not a hard-error, we can continue.
}
Some(Poll::Ready(Err(e))) => {
return Poll::Ready(Err(anyhow!("Portal connection failed: {e}")));
}