mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
refactor(phoenix-channel): reduce Error to fatal errors (#4015)
As part of doing https://github.com/firezone/firezone/pull/3682, we noticed that the handling of errors up to the clients needs to differentiate between fatal errors that require clearing the token vs not. Upon closer inspection of `phoenix_channel::Error`, it becomes obvious that the current design is not good here. In particular, we handle certain errors with retries internally but still expose those same errors. To make this more obvious, we reduce the public `Error` to the variants that are actually fatal. Those can really only be three: - HTTP client errors (those are by definition non-retryable) - Token expired - We have reached our max number of retries
This commit is contained in:
@@ -3,7 +3,7 @@ use crate::messages::{
|
||||
EgressMessages, IngressMessages, RejectAccess, RequestConnection,
|
||||
};
|
||||
use crate::CallbackHandler;
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use anyhow::{bail, Result};
|
||||
use boringtun::x25519::PublicKey;
|
||||
use connlib_shared::{
|
||||
messages::{GatewayResponse, ResourceAccepted, ResourceDescription},
|
||||
@@ -222,9 +222,6 @@ impl Eventloop {
|
||||
// TODO: Handle `init` message during operation.
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(phoenix_channel::Event::Disconnect(reason)) => {
|
||||
return Poll::Ready(Err(anyhow!("Disconnected by portal: {reason}")));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ use secrecy::{CloneableSecret, ExposeSecret as _, Secret};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::task::{ready, Context, Poll};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_tungstenite::tungstenite::http::StatusCode;
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::{handshake::client::Request, Message},
|
||||
@@ -45,7 +46,9 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
|
||||
|
||||
enum State {
|
||||
Connected(WebSocketStream<MaybeTlsStream<TcpStream>>),
|
||||
Connecting(BoxFuture<'static, Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Error>>),
|
||||
Connecting(
|
||||
BoxFuture<'static, Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError>>,
|
||||
),
|
||||
}
|
||||
|
||||
/// Creates a new [PhoenixChannel] to the given endpoint and waits for an `init` message.
|
||||
@@ -113,21 +116,53 @@ pub struct UnexpectedEventDuringInit(String);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("provided URI is missing a host")]
|
||||
MissingHost,
|
||||
#[error("websocket failed")]
|
||||
WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
|
||||
#[error("failed to serialize message")]
|
||||
Serde(#[from] serde_json::Error),
|
||||
#[error("server sent a reply without a reference")]
|
||||
MissingReplyId,
|
||||
#[error("server did not reply to our heartbeat")]
|
||||
#[error("client error: {0}")]
|
||||
ClientError(StatusCode),
|
||||
#[error("token expired")]
|
||||
TokenExpired,
|
||||
#[error("max retries reached")]
|
||||
MaxRetriesReached,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn is_authentication_error(&self) -> bool {
|
||||
match self {
|
||||
Error::ClientError(s) => s == &StatusCode::UNAUTHORIZED || s == &StatusCode::FORBIDDEN,
|
||||
Error::TokenExpired => true,
|
||||
Error::MaxRetriesReached => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum InternalError {
|
||||
WebSocket(tokio_tungstenite::tungstenite::Error),
|
||||
Serde(serde_json::Error),
|
||||
MissedHeartbeat,
|
||||
#[error("connection close message")]
|
||||
CloseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash)]
|
||||
impl fmt::Display for InternalError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(http)) => {
|
||||
let status = http.status();
|
||||
let body = http
|
||||
.body()
|
||||
.as_deref()
|
||||
.map(String::from_utf8_lossy)
|
||||
.unwrap_or_default();
|
||||
|
||||
write!(f, "http error: {status} - {body}")
|
||||
}
|
||||
InternalError::WebSocket(e) => write!(f, "websocket connection failed: {e}"),
|
||||
InternalError::Serde(e) => write!(f, "failed to deserialize message: {e}"),
|
||||
InternalError::MissedHeartbeat => write!(f, "portal did not respond to our heartbeat"),
|
||||
InternalError::CloseMessage => write!(f, "portal closed the websocket connection"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Deserialize, Serialize)]
|
||||
pub struct OutboundRequestId(u64);
|
||||
|
||||
impl OutboundRequestId {
|
||||
@@ -135,6 +170,13 @@ impl OutboundRequestId {
|
||||
pub(crate) fn new(id: u64) -> Self {
|
||||
Self(id)
|
||||
}
|
||||
|
||||
/// Internal function to make a copy.
|
||||
///
|
||||
/// Not exposed publicly because these IDs are meant to be unique.
|
||||
pub(crate) fn copy(&self) -> Self {
|
||||
Self(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for OutboundRequestId {
|
||||
@@ -168,6 +210,10 @@ impl SecureUrl {
|
||||
pub fn host(&self) -> Option<&str> {
|
||||
self.inner.host_str()
|
||||
}
|
||||
|
||||
pub fn port(&self) -> Option<u16> {
|
||||
self.inner.port()
|
||||
}
|
||||
}
|
||||
|
||||
impl CloneableSecret for SecureUrl {}
|
||||
@@ -203,7 +249,9 @@ where
|
||||
secret_url: secret_url.clone(),
|
||||
user_agent: user_agent.clone(),
|
||||
state: State::Connecting(Box::pin(async move {
|
||||
let (stream, _) = connect_async(make_request(secret_url, user_agent)?).await?;
|
||||
let (stream, _) = connect_async(make_request(secret_url, user_agent))
|
||||
.await
|
||||
.map_err(InternalError::WebSocket)?;
|
||||
|
||||
Ok(stream)
|
||||
})),
|
||||
@@ -255,41 +303,28 @@ where
|
||||
|
||||
continue;
|
||||
}
|
||||
Err(InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(
|
||||
r,
|
||||
))) if r.status().is_client_error() => {
|
||||
return Poll::Ready(Err(Error::ClientError(r.status())));
|
||||
}
|
||||
Err(e) => {
|
||||
if let Error::WebSocket(tokio_tungstenite::tungstenite::Error::Http(r)) = &e
|
||||
{
|
||||
let status = r.status();
|
||||
|
||||
if status.is_client_error() {
|
||||
let body = r
|
||||
.body()
|
||||
.as_deref()
|
||||
.map(String::from_utf8_lossy)
|
||||
.unwrap_or_default();
|
||||
|
||||
tracing::warn!(
|
||||
"Fatal client error ({status}) in portal connection: {body}"
|
||||
);
|
||||
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
};
|
||||
|
||||
let Some(backoff) = self.reconnect_backoff.next_backoff() else {
|
||||
tracing::warn!("Reconnect backoff expired");
|
||||
return Poll::Ready(Err(e));
|
||||
return Poll::Ready(Err(Error::MaxRetriesReached));
|
||||
};
|
||||
|
||||
let secret_url = self.secret_url.clone();
|
||||
let user_agent = self.user_agent.clone();
|
||||
|
||||
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {:#}", anyhow::Error::from(e));
|
||||
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}");
|
||||
|
||||
self.state = State::Connecting(Box::pin(async move {
|
||||
tokio::time::sleep(backoff).await;
|
||||
|
||||
let (stream, _) =
|
||||
connect_async(make_request(secret_url, user_agent)?).await?;
|
||||
let (stream, _) = connect_async(make_request(secret_url, user_agent))
|
||||
.await
|
||||
.map_err(InternalError::WebSocket)?;
|
||||
|
||||
Ok(stream)
|
||||
}));
|
||||
@@ -306,7 +341,7 @@ where
|
||||
match stream.start_send_unpin(Message::Text(message)) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
self.reconnect_on_transient_error(Error::WebSocket(e));
|
||||
self.reconnect_on_transient_error(InternalError::WebSocket(e));
|
||||
}
|
||||
}
|
||||
continue;
|
||||
@@ -329,7 +364,7 @@ where
|
||||
{
|
||||
Ok(m) => m,
|
||||
Err(e) if e.is_io() || e.is_eof() => {
|
||||
self.reconnect_on_transient_error(Error::Serde(e));
|
||||
self.reconnect_on_transient_error(InternalError::Serde(e));
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -338,26 +373,25 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
match message.payload {
|
||||
Payload::Message(msg) => {
|
||||
match (message.payload, message.reference) {
|
||||
(Payload::Message(msg), _) => {
|
||||
return Poll::Ready(Ok(Event::InboundMessage {
|
||||
topic: message.topic,
|
||||
msg,
|
||||
}))
|
||||
}
|
||||
Payload::Reply(Reply::Error { reason }) => {
|
||||
(Payload::Reply(_), None) => {
|
||||
tracing::warn!("Discarding reply because server omitted reference");
|
||||
continue;
|
||||
}
|
||||
(Payload::Reply(Reply::Error { reason }), Some(req_id)) => {
|
||||
return Poll::Ready(Ok(Event::ErrorResponse {
|
||||
topic: message.topic,
|
||||
req_id: OutboundRequestId(
|
||||
message.reference.ok_or(Error::MissingReplyId)?,
|
||||
),
|
||||
req_id,
|
||||
reason,
|
||||
}));
|
||||
}
|
||||
Payload::Reply(Reply::Ok(OkReply::Message(reply))) => {
|
||||
let req_id =
|
||||
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
|
||||
|
||||
(Payload::Reply(Reply::Ok(OkReply::Message(reply))), Some(req_id)) => {
|
||||
if self.pending_join_requests.remove(&req_id) {
|
||||
tracing::info!("Joined {} room on portal", message.topic);
|
||||
|
||||
@@ -373,41 +407,39 @@ where
|
||||
res: reply,
|
||||
}));
|
||||
}
|
||||
Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))) => {
|
||||
let id =
|
||||
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
|
||||
|
||||
if self.heartbeat.maybe_handle_reply(id) {
|
||||
(Payload::Reply(Reply::Ok(OkReply::NoMessage(Empty {}))), Some(req_id)) => {
|
||||
if self.heartbeat.maybe_handle_reply(req_id.copy()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::trace!(
|
||||
"Received empty reply for request {:?}",
|
||||
message.reference
|
||||
);
|
||||
tracing::trace!("Received empty reply for request {req_id:?}");
|
||||
|
||||
continue;
|
||||
}
|
||||
Payload::Error(Empty {}) => {
|
||||
return Poll::Ready(Ok(Event::ErrorResponse {
|
||||
topic: message.topic,
|
||||
req_id: OutboundRequestId(
|
||||
message.reference.ok_or(Error::MissingReplyId)?,
|
||||
),
|
||||
reason: ErrorReply::Other,
|
||||
}))
|
||||
}
|
||||
Payload::Close(Empty {}) => {
|
||||
self.reconnect_on_transient_error(Error::CloseMessage);
|
||||
(Payload::Error(Empty {}), reference) => {
|
||||
tracing::debug!(
|
||||
?reference,
|
||||
topic = &message.topic,
|
||||
"Received empty error response"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
Payload::Disconnect { reason } => {
|
||||
return Poll::Ready(Ok(Event::Disconnect(reason)));
|
||||
(Payload::Close(Empty {}), _) => {
|
||||
self.reconnect_on_transient_error(InternalError::CloseMessage);
|
||||
continue;
|
||||
}
|
||||
(
|
||||
Payload::Disconnect {
|
||||
reason: DisconnectReason::TokenExpired,
|
||||
},
|
||||
_,
|
||||
) => {
|
||||
return Poll::Ready(Err(Error::TokenExpired));
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => {
|
||||
self.reconnect_on_transient_error(Error::WebSocket(e));
|
||||
self.reconnect_on_transient_error(InternalError::WebSocket(e));
|
||||
continue;
|
||||
}
|
||||
_ => (),
|
||||
@@ -422,7 +454,7 @@ where
|
||||
return Poll::Ready(Ok(Event::HeartbeatSent));
|
||||
}
|
||||
Poll::Ready(Err(MissedLastHeartbeat {})) => {
|
||||
self.reconnect_on_transient_error(Error::MissedHeartbeat);
|
||||
self.reconnect_on_transient_error(InternalError::MissedHeartbeat);
|
||||
continue;
|
||||
}
|
||||
_ => (),
|
||||
@@ -434,7 +466,7 @@ where
|
||||
tracing::trace!("Flushed websocket");
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
self.reconnect_on_transient_error(Error::WebSocket(e));
|
||||
self.reconnect_on_transient_error(InternalError::WebSocket(e));
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {}
|
||||
@@ -447,7 +479,7 @@ where
|
||||
/// Sets the channels state to [`State::Connecting`] with the given error.
|
||||
///
|
||||
/// The [`PhoenixChannel::poll`] function will handle the reconnect if appropriate for the given error.
|
||||
fn reconnect_on_transient_error(&mut self, e: Error) {
|
||||
fn reconnect_on_transient_error(&mut self, e: InternalError) {
|
||||
self.state = State::Connecting(future::ready(Err(e)).boxed())
|
||||
}
|
||||
|
||||
@@ -459,19 +491,23 @@ where
|
||||
let request_id = self.fetch_add_request_id();
|
||||
|
||||
// We don't care about the reply type when serializing
|
||||
let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, request_id))
|
||||
.expect("we should always be able to serialize a join topic message");
|
||||
let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new(
|
||||
topic,
|
||||
payload,
|
||||
request_id.copy(),
|
||||
))
|
||||
.expect("we should always be able to serialize a join topic message");
|
||||
|
||||
self.pending_messages.push_back(msg);
|
||||
|
||||
OutboundRequestId(request_id)
|
||||
request_id
|
||||
}
|
||||
|
||||
fn fetch_add_request_id(&mut self) -> u64 {
|
||||
fn fetch_add_request_id(&mut self) -> OutboundRequestId {
|
||||
let next_id = self.next_request_id;
|
||||
self.next_request_id += 1;
|
||||
|
||||
next_id
|
||||
OutboundRequestId(next_id)
|
||||
}
|
||||
|
||||
/// Cast this instance of [PhoenixChannel] to new message types.
|
||||
@@ -516,17 +552,16 @@ pub enum Event<TInboundMsg, TOutboundRes> {
|
||||
topic: String,
|
||||
msg: TInboundMsg,
|
||||
},
|
||||
Disconnect(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct PhoenixMessage<T, R> {
|
||||
// TODO: we should use a newtype pattern for topics
|
||||
topic: String,
|
||||
#[serde(flatten)]
|
||||
payload: Payload<T, R>,
|
||||
#[serde(rename = "ref")]
|
||||
reference: Option<u64>,
|
||||
reference: Option<OutboundRequestId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
||||
@@ -539,7 +574,7 @@ enum Payload<T, R> {
|
||||
#[serde(rename = "phx_close")]
|
||||
Close(Empty),
|
||||
#[serde(rename = "disconnect")]
|
||||
Disconnect { reason: String },
|
||||
Disconnect { reason: DisconnectReason },
|
||||
#[serde(untagged)]
|
||||
Message(T),
|
||||
}
|
||||
@@ -569,7 +604,6 @@ enum OkReply<T> {
|
||||
pub enum ErrorReply {
|
||||
#[serde(rename = "unmatched topic")]
|
||||
UnmatchedTopic,
|
||||
TokenExpired,
|
||||
NotFound,
|
||||
Offline,
|
||||
Disabled,
|
||||
@@ -577,8 +611,14 @@ pub enum ErrorReply {
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum DisconnectReason {
|
||||
TokenExpired,
|
||||
}
|
||||
|
||||
impl<T, R> PhoenixMessage<T, R> {
|
||||
pub fn new(topic: impl Into<String>, payload: T, reference: u64) -> Self {
|
||||
pub fn new(topic: impl Into<String>, payload: T, reference: OutboundRequestId) -> Self {
|
||||
Self {
|
||||
topic: topic.into(),
|
||||
payload: Payload::Message(payload),
|
||||
@@ -588,37 +628,35 @@ impl<T, R> PhoenixMessage<T, R> {
|
||||
}
|
||||
|
||||
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
|
||||
fn make_request(secret_url: Secret<SecureUrl>, user_agent: String) -> Result<Request, Error> {
|
||||
fn make_request(secret_url: Secret<SecureUrl>, user_agent: String) -> Request {
|
||||
use secrecy::ExposeSecret;
|
||||
|
||||
let host = secret_url
|
||||
.expose_secret()
|
||||
.inner
|
||||
.host()
|
||||
.ok_or(Error::MissingHost)?;
|
||||
let host = if let Some(port) = secret_url.expose_secret().inner.port() {
|
||||
format!("{host}:{port}")
|
||||
} else {
|
||||
host.to_string()
|
||||
};
|
||||
|
||||
let mut r = [0u8; 16];
|
||||
OsRng.fill_bytes(&mut r);
|
||||
let key = base64::engine::general_purpose::STANDARD.encode(r);
|
||||
|
||||
let req = Request::builder()
|
||||
let mut req_builder = Request::builder()
|
||||
.method("GET")
|
||||
.header("Host", host)
|
||||
.header("Connection", "Upgrade")
|
||||
.header("Upgrade", "websocket")
|
||||
.header("Sec-WebSocket-Version", "13")
|
||||
.header("Sec-WebSocket-Key", key)
|
||||
.header("User-Agent", user_agent)
|
||||
.uri(secret_url.expose_secret().inner.as_str())
|
||||
.body(())
|
||||
.expect("building static request always works");
|
||||
.uri(secret_url.expose_secret().inner.as_str());
|
||||
|
||||
Ok(req)
|
||||
if let Some(host) = secret_url.expose_secret().host() {
|
||||
let host = secret_url
|
||||
.expose_secret()
|
||||
.port()
|
||||
.map(|port| format!("{host}:{port}"))
|
||||
.unwrap_or(host.to_string());
|
||||
|
||||
req_builder = req_builder.header("Host", host);
|
||||
}
|
||||
|
||||
req_builder
|
||||
.body(())
|
||||
.expect("building static request always works")
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
@@ -727,7 +765,7 @@ mod tests {
|
||||
"#;
|
||||
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
||||
let expected_reply = Payload::<(), ()>::Disconnect {
|
||||
reason: "token_expired".to_string(),
|
||||
reason: DisconnectReason::TokenExpired,
|
||||
};
|
||||
assert_eq!(actual_reply, expected_reply);
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use futures::channel::mpsc;
|
||||
use futures::{future, FutureExt, SinkExt, StreamExt};
|
||||
use opentelemetry::{sdk, KeyValue};
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
use phoenix_channel::{Error, Event, PhoenixChannel, SecureUrl};
|
||||
use phoenix_channel::{Event, PhoenixChannel, SecureUrl};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use secrecy::{Secret, SecretString};
|
||||
@@ -501,13 +501,6 @@ where
|
||||
|
||||
// Priority 5: Handle portal messages
|
||||
match self.channel.as_mut().map(|c| c.poll(cx)) {
|
||||
Some(Poll::Ready(Ok(Event::Disconnect(reason)))) => {
|
||||
return Poll::Ready(Err(anyhow!("Connection closed by portal: {reason}")));
|
||||
}
|
||||
Some(Poll::Ready(Err(Error::Serde(e)))) => {
|
||||
tracing::warn!(target: "relay", "Failed to deserialize portal message: {e}");
|
||||
continue; // This is not a hard-error, we can continue.
|
||||
}
|
||||
Some(Poll::Ready(Err(e))) => {
|
||||
return Poll::Ready(Err(anyhow!("Portal connection failed: {e}")));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user