Feat/connlib handle error messages (#1735)

With this PR the client handles error messages caused by the
gateway/relay, although rate limiting is still needed.

Waiting for #1729 to be merged.
This commit is contained in:
Gabi
2023-07-06 15:47:01 -03:00
committed by GitHub
parent db4bdb0791
commit c817473aef
7 changed files with 102 additions and 33 deletions

View File

@@ -135,8 +135,6 @@ services:
depends_on:
api:
condition: 'service_healthy'
relay:
condition: 'service_healthy'
networks:
app:
ipv4_address: 172.28.0.100

View File

@@ -3,10 +3,10 @@ use std::{sync::Arc, time::Duration};
use crate::messages::{Connect, EgressMessages, InitClient, Messages, Relays};
use boringtun::x25519::StaticSecret;
use libs_common::{
control::PhoenixSenderWithTopic,
error_type::ErrorType::{Fatal, Recoverable},
control::{ErrorInfo, ErrorReply, MessageResult, PhoenixSenderWithTopic},
error_type::ErrorType::{self, Fatal, Recoverable},
messages::{Id, ResourceDescription},
Callbacks, ControlSession, Result,
Callbacks, ControlSession, Error, Result,
};
use async_trait::async_trait;
@@ -19,9 +19,14 @@ impl ControlSignal for ControlSignaler {
self.control_signal
// It's easier if self is not mut
.clone()
.send(EgressMessages::ListRelays {
resource_id: resource.id(),
})
.send_with_ref(
EgressMessages::ListRelays {
resource_id: resource.id(),
},
// The resource id functions as the connection id since we can only have one connection
// outgoing for each resource.
resource.id(),
)
.await?;
Ok(())
}
@@ -40,11 +45,16 @@ struct ControlSignaler {
impl<CB: Callbacks + 'static> ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(self))]
async fn start(mut self, mut receiver: Receiver<Messages>) {
async fn start(mut self, mut receiver: Receiver<MessageResult<Messages>>) {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
tokio::select! {
Some(msg) = receiver.recv() => self.handle_message(msg).await,
Some(msg) = receiver.recv() => {
match msg {
Ok(msg) => self.handle_message(msg).await,
Err(msg_reply) => self.handle_error(msg_reply).await,
}
},
_ = interval.tick() => self.stats_event().await,
else => break
}
@@ -126,7 +136,10 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
if let Err(err) = control_signaler
.control_signal
// TODO: create a reference number and keep track for the response
.send_with_ref(EgressMessages::RequestConnection(connection_request), 0)
.send_with_ref(
EgressMessages::RequestConnection(connection_request),
resource_id,
)
.await
{
tunnel.cleanup_connection(resource_id);
@@ -153,6 +166,33 @@ impl<CB: Callbacks + 'static> ControlPlane<CB> {
}
}
#[tracing::instrument(level = "trace", skip(self))]
pub(super) async fn handle_error(&mut self, reply_error: ErrorReply) {
if matches!(reply_error.error, ErrorInfo::Offline) {
match reply_error.reference {
Some(reference) => {
let Ok(id) = reference.parse() else {
tracing::error!(
"An offline error came back with a reference to a non-valid resource id"
);
self.tunnel.callbacks().on_error(&Error::ControlProtocolError, ErrorType::Recoverable);
return;
};
// TODO: Rate limit the number of attempts of getting the relays before just trying a local network connection
self.tunnel.cleanup_connection(id);
}
None => {
tracing::error!(
"An offline portal error came without a reference that originated the error"
);
self.tunnel
.callbacks()
.on_error(&Error::ControlProtocolError, ErrorType::Recoverable);
}
}
}
}
#[tracing::instrument(level = "trace", skip(self))]
pub(super) async fn stats_event(&mut self) {
// TODO
@@ -164,7 +204,7 @@ impl<CB: Callbacks + 'static> ControlSession<Messages, CB> for ControlPlane<CB>
#[tracing::instrument(level = "trace", skip(private_key, callbacks))]
async fn start(
private_key: StaticSecret,
receiver: Receiver<Messages>,
receiver: Receiver<MessageResult<Messages>>,
control_signal: PhoenixSenderWithTopic,
callbacks: CB,
) -> Result<()> {

View File

@@ -128,7 +128,7 @@ mod test {
#[test]
fn connection_ready_deserialization() {
let message = r#"{
"ref": 0,
"ref": "0",
"topic": "device",
"event": "phx_reply",
"payload": {

View File

@@ -68,7 +68,6 @@ fn make_request(uri: &Url) -> Result<Request> {
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", key)
// TODO: Get OS Info here (os_info crate)
.header("User-Agent", get_user_agent())
.uri(uri.as_str())
.body(())?;
@@ -80,7 +79,7 @@ where
I: DeserializeOwned,
R: DeserializeOwned,
M: From<I> + From<R>,
F: Fn(M) -> Fut,
F: Fn(MessageResult<M>) -> Fut,
Fut: Future<Output = ()> + Send + 'static,
{
/// Starts the tunnel with the parameters given in [Self::new].
@@ -169,16 +168,23 @@ where
match message.into_text() {
Ok(m_str) => match serde_json::from_str::<PhoenixMessage<I, R>>(&m_str) {
Ok(m) => match m.payload {
Payload::Message(m) => handler(m.into()).await,
Payload::Message(m) => handler(Ok(m.into())).await,
Payload::Reply(status) => match status {
ReplyMessage::PhxReply(phx_reply) => match phx_reply {
// TODO: Here we should pass error info to a subscriber
PhxReply::Error(info) => tracing::error!("Portal error: {info:?}"),
PhxReply::Error(info) => {
tracing::warn!("Portal error: {info:?}");
handler(Err(ErrorReply {
error: info,
reference: m.reference,
}))
.await
}
PhxReply::Ok(reply) => match reply {
OkReply::NoMessage(Empty {}) => {
tracing::trace!("Phoenix status message")
}
OkReply::Message(m) => handler(m.into()).await,
OkReply::Message(m) => handler(Ok(m.into())).await,
},
},
ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"),
@@ -232,6 +238,20 @@ where
}
}
/// A result type that is used to communicate to the client/gateway
/// control loop the message received.
///
/// `Ok` carries a received message converted into `M`; `Err` carries the
/// [`ErrorReply`] built from an error reply.
pub type MessageResult<M> = std::result::Result<M, ErrorReply>;
/// This struct holds info about an error reply which will be passed
/// to connlib's control plane.
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
pub struct ErrorReply {
/// Information of the error
pub error: ErrorInfo,
/// Reference to the message that caused the error, or `None` when the
/// reply did not carry a reference
pub reference: Option<String>,
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(untagged)]
enum Payload<T, R> {
@@ -248,11 +268,11 @@ pub struct PhoenixMessage<T, R> {
#[serde(flatten)]
payload: Payload<T, R>,
#[serde(rename = "ref")]
reference: Option<i32>,
reference: Option<String>,
}
impl<T, R> PhoenixMessage<T, R> {
pub fn new(topic: impl Into<String>, payload: T, reference: Option<i32>) -> Self {
pub fn new(topic: impl Into<String>, payload: T, reference: Option<String>) -> Self {
Self {
topic: topic.into(),
payload: Payload::Message(payload),
@@ -298,9 +318,10 @@ enum OkReply<T> {
NoMessage(Empty),
}
/// This represents the info we have about the error
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum ErrorInfo {
pub enum ErrorInfo {
Reason(String),
Offline,
}
@@ -342,7 +363,11 @@ impl PhoenixSenderWithTopic {
/// Sends a message to the associated topic using a [PhoenixSender] also setting the ref
///
/// See [PhoenixSender::send]
pub async fn send_with_ref(&mut self, payload: impl Serialize, reference: i32) -> Result<()> {
pub async fn send_with_ref(
&mut self,
payload: impl Serialize,
reference: impl ToString,
) -> Result<()> {
self.phoenix_sender
.send_with_ref(&self.topic, payload, reference)
.await
@@ -354,7 +379,7 @@ impl PhoenixSender {
&mut self,
topic: impl Into<String>,
payload: impl Serialize,
reference: Option<i32>,
reference: Option<String>,
) -> Result<()> {
// We don't care about the reply type when serializing
let str = serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, reference))?;
@@ -381,9 +406,10 @@ impl PhoenixSender {
&mut self,
topic: impl Into<String>,
payload: impl Serialize,
reference: i32,
reference: impl ToString,
) -> Result<()> {
self.send_internal(topic, payload, Some(reference)).await
self.send_internal(topic, payload, Some(reference.to_string()))
.await
}
/// Join a phoenix topic, meaning that after this method is invoked [PhoenixChannel] will

View File

@@ -13,7 +13,7 @@ use url::Url;
use uuid::Uuid;
use crate::{
control::{PhoenixChannel, PhoenixSenderWithTopic},
control::{MessageResult, PhoenixChannel, PhoenixSenderWithTopic},
error_type::ErrorType,
messages::{Key, ResourceDescription, ResourceDescriptionCidr},
Error, Result,
@@ -26,7 +26,7 @@ pub trait ControlSession<T, CB: Callbacks> {
/// Start control-plane with the given private-key in the background.
async fn start(
private_key: StaticSecret,
receiver: Receiver<T>,
receiver: Receiver<MessageResult<T>>,
control_signal: PhoenixSenderWithTopic,
callbacks: CB,
) -> Result<()>;

View File

@@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
use boringtun::x25519::StaticSecret;
use firezone_tunnel::{ControlSignal, Tunnel};
use libs_common::{
control::PhoenixSenderWithTopic,
control::{MessageResult, PhoenixSenderWithTopic},
error_type::ErrorType::{Fatal, Recoverable},
messages::ResourceDescription,
Callbacks, ControlSession, Result,
@@ -36,11 +36,16 @@ impl ControlSignal for ControlSignaler {
impl<CB: Callbacks + 'static> ControlPlane<CB> {
#[tracing::instrument(level = "trace", skip(self))]
async fn start(mut self, mut receiver: Receiver<IngressMessages>) {
async fn start(mut self, mut receiver: Receiver<MessageResult<IngressMessages>>) {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
tokio::select! {
Some(msg) = receiver.recv() => self.handle_message(msg).await,
Some(msg) = receiver.recv() => {
match msg {
Ok(msg) => self.handle_message(msg).await,
Err(_msg_reply) => todo!(),
}
},
_ = interval.tick() => self.stats_event().await,
else => break
}
@@ -123,7 +128,7 @@ impl<CB: Callbacks + 'static> ControlSession<IngressMessages, CB> for ControlPla
#[tracing::instrument(level = "trace", skip(private_key, callbacks))]
async fn start(
private_key: StaticSecret,
receiver: Receiver<IngressMessages>,
receiver: Receiver<MessageResult<IngressMessages>>,
control_signal: PhoenixSenderWithTopic,
callbacks: CB,
) -> Result<()> {

View File

@@ -88,7 +88,7 @@ mod device_channel;
mod device_channel;
const RESET_PACKET_COUNT_INTERVAL: Duration = Duration::from_secs(1);
const REFRESH_PEERS_TIEMRS_INTERVAL: Duration = Duration::from_secs(1);
const REFRESH_PEERS_TIMERS_INTERVAL: Duration = Duration::from_secs(1);
// Note: Taken from boringtun
const HANDSHAKE_RATE_LIMIT: u64 = 100;
@@ -282,7 +282,7 @@ where
let tunnel = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(REFRESH_PEERS_TIEMRS_INTERVAL);
let mut interval = tokio::time::interval(REFRESH_PEERS_TIMERS_INTERVAL);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
let mut dst_buf = [0u8; MAX_UDP_SIZE];