mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
577 lines
20 KiB
Rust
577 lines
20 KiB
Rust
//! Control protocol related module.
|
|
//!
|
|
//! This modules contains the logic for handling in and out messages through the control plane.
|
|
//! Handling of the message itself can be found in the other lib crates.
|
|
//!
|
|
//! Entrypoint for this module is [PhoenixChannel].
|
|
use std::{marker::PhantomData, time::Duration};
|
|
|
|
use base64::Engine;
|
|
use futures::{
|
|
channel::mpsc::{channel, Receiver, Sender},
|
|
TryStreamExt,
|
|
};
|
|
use futures_util::{Future, SinkExt, StreamExt, TryFutureExt};
|
|
use rand_core::{OsRng, RngCore};
|
|
use secrecy::Secret;
|
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
|
use tokio_stream::StreamExt as _;
|
|
use tokio_tungstenite::{
|
|
connect_async,
|
|
tungstenite::{self, handshake::client::Request},
|
|
};
|
|
use tungstenite::Message;
|
|
use url::Url;
|
|
|
|
use crate::{get_user_agent, Error, Result};
|
|
|
|
const CHANNEL_SIZE: usize = 1_000;
|
|
const HEARTBEAT: Duration = Duration::from_secs(30);
|
|
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(35);
|
|
|
|
pub type Reference = String;
|
|
|
|
// TODO: Refactor this PhoenixChannel to use the top-level phoenix-channel crate instead.
|
|
// See https://github.com/firezone/firezone/issues/2158
|
|
pub struct SecureUrl {
|
|
pub inner: Url,
|
|
}
|
|
impl SecureUrl {
|
|
pub fn from_url(url: Url) -> Self {
|
|
Self { inner: url }
|
|
}
|
|
}
|
|
impl secrecy::Zeroize for SecureUrl {
|
|
fn zeroize(&mut self) {
|
|
let placeholder = Url::parse("http://a.com").expect("placeholder URL to be valid");
|
|
let _ = std::mem::replace(&mut self.inner, placeholder);
|
|
}
|
|
}
|
|
/// Main struct to interact with the control-protocol channel.
|
|
///
|
|
/// After creating a new `PhoenixChannel` using [PhoenixChannel::new] you need to
|
|
/// use [start][PhoenixChannel::start] for the channel to do anything.
|
|
///
|
|
/// If you want to send something through the channel you need to obtain a [PhoenixSender] through
|
|
/// [PhoenixChannel::sender], this will already clone the sender so no need to clone it after you obtain it.
|
|
///
|
|
/// When [PhoenixChannel::start] is called a new websocket is created that will listen message from the control plane
|
|
/// based on the parameters passed on [new][PhoenixChannel::new], from then on any messages sent with a sender
|
|
/// obtained by [PhoenixChannel::sender] will be forwarded to the websocket up to the control plane. Ingress messages
|
|
/// will be passed on to the `handler` provided in [PhoenixChannel::new].
|
|
///
|
|
/// The future returned by [PhoenixChannel::start] will finish when the websocket closes (by an error), meaning that if you
|
|
/// `await` it, it will block until you use `close` in a [PhoenixSender], the portal close the connection or something goes wrong.
|
|
pub struct PhoenixChannel<F, I, R, M> {
|
|
secret_url: Secret<SecureUrl>,
|
|
os_version_override: Option<String>,
|
|
handler: F,
|
|
sender: Sender<Message>,
|
|
receiver: Receiver<Message>,
|
|
_phantom: PhantomData<(I, R, M)>,
|
|
}
|
|
|
|
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
|
|
fn make_request(
|
|
secret_url: &Secret<SecureUrl>,
|
|
os_version_override: Option<String>,
|
|
) -> Result<Request> {
|
|
use secrecy::ExposeSecret;
|
|
|
|
let host = secret_url
|
|
.expose_secret()
|
|
.inner
|
|
.host()
|
|
.ok_or(Error::UriError)?;
|
|
let host = if let Some(port) = secret_url.expose_secret().inner.port() {
|
|
format!("{host}:{port}")
|
|
} else {
|
|
host.to_string()
|
|
};
|
|
|
|
let mut r = [0u8; 16];
|
|
OsRng.fill_bytes(&mut r);
|
|
let key = base64::engine::general_purpose::STANDARD.encode(r);
|
|
|
|
let req = Request::builder()
|
|
.method("GET")
|
|
.header("Host", host)
|
|
.header("Connection", "Upgrade")
|
|
.header("Upgrade", "websocket")
|
|
.header("Sec-WebSocket-Version", "13")
|
|
.header("Sec-WebSocket-Key", key)
|
|
.header("User-Agent", get_user_agent(os_version_override))
|
|
.uri(secret_url.expose_secret().inner.as_str())
|
|
.body(())?;
|
|
Ok(req)
|
|
}
|
|
|
|
impl<F, Fut, I, R, M> PhoenixChannel<F, I, R, M>
|
|
where
|
|
I: DeserializeOwned,
|
|
R: DeserializeOwned,
|
|
M: From<I> + From<R>,
|
|
F: Fn(MessageResult<M>, Option<Reference>, String) -> Fut,
|
|
Fut: Future<Output = ()> + Send + 'static,
|
|
{
|
|
/// Starts the tunnel with the parameters given in [Self::new].
|
|
///
|
|
// (Note: we could add a generic list of messages but this is easier)
|
|
/// Additionally, you can add a list of topic to join after connection ASAP.
|
|
///
|
|
/// See [struct-level docs][PhoenixChannel] for more info.
|
|
///
|
|
// TODO: this is not very elegant but it was the easiest way to do reset the exponential backoff for now
|
|
/// Furthermore, it calls the given callback once it connects to the portal.
|
|
#[tracing::instrument(level = "trace", skip(self, after_connection_ends))]
|
|
pub async fn start(
|
|
&mut self,
|
|
topics: Vec<String>,
|
|
after_connection_ends: impl FnOnce(),
|
|
) -> Result<()> {
|
|
tracing::trace!("Trying to connect to portal...");
|
|
|
|
let (ws_stream, _) = connect_async(make_request(
|
|
&self.secret_url,
|
|
self.os_version_override.clone(),
|
|
)?)
|
|
.await?;
|
|
|
|
tracing::trace!("Successfully connected to portal");
|
|
|
|
let (mut write, read) = ws_stream.split();
|
|
|
|
let mut sender = self.sender();
|
|
let Self {
|
|
handler, receiver, ..
|
|
} = self;
|
|
|
|
let process_messages = tokio_stream::StreamExt::map(read.timeout(HEARTBEAT_TIMEOUT), |m| {
|
|
m.map_err(Error::from)?.map_err(Error::from)
|
|
})
|
|
.try_for_each(|message| async {
|
|
Self::message_process(handler, message).await;
|
|
Ok(())
|
|
});
|
|
|
|
// Would we like to do write.send_all(futures::stream(Message::text(...))) ?
|
|
// yes.
|
|
// but since write is taken by reference rust doesn't believe this future is sendable anymore
|
|
// so this works for now, since we only use it with 1 topic.
|
|
for topic in topics {
|
|
write
|
|
.send(Message::Text(
|
|
// We don't care about the reply type when serializing
|
|
serde_json::to_string(&PhoenixMessage::<_, ()>::new(
|
|
topic,
|
|
EgressControlMessage::PhxJoin(Empty {}),
|
|
None,
|
|
))
|
|
.expect("we should always be able to serialize a join topic message"),
|
|
))
|
|
.await?;
|
|
}
|
|
|
|
// TODO: is Forward cancel safe?
|
|
// I would assume it is and that's the advantage over
|
|
// while let Some(item) = receiver.next().await { write.send(item) } ...
|
|
// but double check this!
|
|
// If it's not cancel safe this means an item can be consumed and never sent.
|
|
// Furthermore can this also happen if write errors out? *that* I'd assume is possible...
|
|
// What option is left? write a new future to forward items.
|
|
// For now we should never assume that an item arrived the portal because we sent it!
|
|
let send_messages = futures::StreamExt::map(receiver, Ok)
|
|
.forward(write)
|
|
.map_err(Error::from);
|
|
|
|
let phoenix_heartbeat = tokio::spawn(async move {
|
|
let mut timer = tokio::time::interval(HEARTBEAT);
|
|
loop {
|
|
timer.tick().await;
|
|
let Ok(_) = sender
|
|
.send("phoenix", EgressControlMessage::Heartbeat(Empty {}))
|
|
.await
|
|
else {
|
|
break;
|
|
};
|
|
}
|
|
});
|
|
|
|
futures_util::pin_mut!(process_messages, send_messages);
|
|
// processing messages should be quick otherwise it'd block sending messages.
|
|
// we could remove this limitation by spawning a separate task for each of these.
|
|
let result = futures::future::select(process_messages, send_messages)
|
|
.await
|
|
.factor_first()
|
|
.0;
|
|
phoenix_heartbeat.abort();
|
|
|
|
after_connection_ends();
|
|
|
|
result?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument(level = "trace", skip(handler))]
|
|
async fn message_process(handler: &F, message: tungstenite::Message) {
|
|
tracing::trace!("{message:?}");
|
|
|
|
match message.into_text() {
|
|
Ok(m_str) => match serde_json::from_str::<PhoenixMessage<I, R>>(&m_str) {
|
|
Ok(m) => match m.payload {
|
|
Payload::Message(payload) => {
|
|
handler(Ok(payload.into()), m.reference, m.topic).await
|
|
}
|
|
Payload::Reply(status) => match status {
|
|
ReplyMessage::PhxReply(phx_reply) => match phx_reply {
|
|
// TODO: Here we should pass error info to a subscriber
|
|
PhxReply::Error(info) => {
|
|
tracing::debug!("Portal error: {info:?}");
|
|
handler(Err(ErrorReply { error: info }), m.reference, m.topic).await
|
|
}
|
|
PhxReply::Ok(reply) => match reply {
|
|
OkReply::NoMessage(Empty {}) => {
|
|
tracing::trace!(target: "phoenix_status", "Phoenix status message")
|
|
}
|
|
OkReply::Message(payload) => {
|
|
handler(Ok(payload.into()), m.reference, m.topic).await
|
|
}
|
|
},
|
|
},
|
|
ReplyMessage::PhxError(Empty {}) => tracing::error!("Phoenix error"),
|
|
},
|
|
},
|
|
Err(e) => {
|
|
tracing::error!(message = "Error deserializing message", message_string = m_str, error = ?e);
|
|
}
|
|
},
|
|
_ => tracing::error!("Received message that is not text"),
|
|
}
|
|
}
|
|
|
|
/// Obtains a new sender that can be used to send message with this [PhoenixChannel] to the portal.
|
|
///
|
|
/// Note that for the sender to relay any message will need the future returned [PhoenixChannel::start] to be polled (await it),
|
|
/// and [PhoenixChannel::start] takes `&mut self`, meaning you need to get the sender before running [PhoenixChannel::start].
|
|
pub fn sender(&self) -> PhoenixSender {
|
|
PhoenixSender {
|
|
sender: self.sender.clone(),
|
|
}
|
|
}
|
|
|
|
/// Obtains a new sender that can be used to send message with this [PhoenixChannel] to the portal for a fixed topic.
|
|
///
|
|
/// For more info see [PhoenixChannel::sender].
|
|
pub fn sender_with_topic(&self, topic: String) -> PhoenixSenderWithTopic {
|
|
PhoenixSenderWithTopic {
|
|
topic,
|
|
phoenix_sender: self.sender(),
|
|
}
|
|
}
|
|
|
|
/// Creates a new [PhoenixChannel] not started yet.
|
|
///
|
|
/// # Parameters:
|
|
/// - `secret_url`: Portal's websocket uri
|
|
/// - `handler`: The handle that will be called for each received message.
|
|
///
|
|
/// For more info see [struct-level docs][PhoenixChannel].
|
|
pub fn new(
|
|
secret_url: Secret<SecureUrl>,
|
|
os_version_override: Option<String>,
|
|
handler: F,
|
|
) -> Self {
|
|
let (sender, receiver) = channel(CHANNEL_SIZE);
|
|
|
|
Self {
|
|
sender,
|
|
receiver,
|
|
secret_url,
|
|
os_version_override,
|
|
handler,
|
|
_phantom: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A result type that is used to communicate to the client/gateway
|
|
/// control loop the message received.
|
|
pub type MessageResult<M> = std::result::Result<M, ErrorReply>;
|
|
|
|
/// This struct holds info about an error reply which will be passed
|
|
/// to connlib's control plane.
|
|
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
|
pub struct ErrorReply {
|
|
/// Information of the error
|
|
pub error: ErrorInfo,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
|
#[serde(untagged)]
|
|
enum Payload<T, R> {
|
|
// We might want other type for the reply message
|
|
// but that makes everything even more convoluted!
|
|
// and we need to think how to make this whole mess less convoluted.
|
|
Reply(ReplyMessage<R>),
|
|
Message(T),
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
|
|
pub struct PhoenixMessage<T, R> {
|
|
// TODO: we should use a newtype pattern for topics
|
|
topic: String,
|
|
#[serde(flatten)]
|
|
payload: Payload<T, R>,
|
|
#[serde(rename = "ref")]
|
|
reference: Option<String>,
|
|
}
|
|
|
|
impl<T, R> PhoenixMessage<T, R> {
|
|
pub fn new(topic: impl Into<String>, payload: T, reference: Option<String>) -> Self {
|
|
Self {
|
|
topic: topic.into(),
|
|
payload: Payload::Message(payload),
|
|
reference,
|
|
}
|
|
}
|
|
|
|
pub fn new_ok_reply(
|
|
topic: impl Into<String>,
|
|
payload: R,
|
|
reference: impl Into<Option<String>>,
|
|
) -> Self {
|
|
Self {
|
|
topic: topic.into(),
|
|
// There has to be a better way :\
|
|
payload: Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::Message(
|
|
payload,
|
|
)))),
|
|
reference: reference.into(),
|
|
}
|
|
}
|
|
|
|
pub fn new_err_reply(
|
|
topic: impl Into<String>,
|
|
error: ErrorInfo,
|
|
reference: impl Into<Option<String>>,
|
|
) -> Self {
|
|
Self {
|
|
topic: topic.into(),
|
|
// There has to be a better way :\
|
|
payload: Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(error))),
|
|
reference: reference.into(),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Awful hack to get serde_json to generate an empty "{}" instead of using "null"
|
|
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
|
|
#[serde(deny_unknown_fields)]
|
|
struct Empty {}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
|
enum EgressControlMessage {
|
|
PhxJoin(Empty),
|
|
Heartbeat(Empty),
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
|
enum ReplyMessage<T> {
|
|
PhxReply(PhxReply<T>),
|
|
PhxError(Empty),
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
#[serde(untagged)]
|
|
enum OkReply<T> {
|
|
Message(T),
|
|
NoMessage(Empty),
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
pub struct UnknownError(pub String);
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
pub enum KnownError {
|
|
#[serde(rename = "unmatched topic")]
|
|
UnmatchedTopic,
|
|
#[serde(rename = "token_expired")]
|
|
TokenExpired,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
#[serde(untagged)]
|
|
pub enum Reason {
|
|
Known(KnownError),
|
|
Unknown(UnknownError),
|
|
}
|
|
|
|
/// This represents the info we have about the error
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum ErrorInfo {
|
|
Reason(Reason),
|
|
Offline,
|
|
Disabled,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
|
|
enum PhxReply<T> {
|
|
Ok(OkReply<T>),
|
|
Error(ErrorInfo),
|
|
}
|
|
|
|
/// You can use this sender to send messages through a `PhoenixChannel`.
|
|
///
|
|
/// Messages won't be sent unless [PhoenixChannel::start] is running, internally
|
|
/// this sends messages through a future channel that are forwrarded then in [PhoenixChannel] event loop
|
|
#[derive(Clone, Debug)]
|
|
pub struct PhoenixSender {
|
|
sender: Sender<Message>,
|
|
}
|
|
|
|
/// Like a [PhoenixSender] with a fixed topic for simplicity
|
|
///
|
|
/// You can obtain it through [PhoenixChannel::sender_with_topic]
|
|
/// See [PhoenixSender] docs and use that if you need more control.
|
|
#[derive(Clone, Debug)]
|
|
pub struct PhoenixSenderWithTopic {
|
|
phoenix_sender: PhoenixSender,
|
|
topic: String,
|
|
}
|
|
|
|
impl PhoenixSenderWithTopic {
|
|
/// Sends a message to the associated topic using a [PhoenixSender]
|
|
///
|
|
/// See [PhoenixSender::send]
|
|
pub async fn send(&mut self, payload: impl Serialize) -> Result<()> {
|
|
self.phoenix_sender.send(&self.topic, payload).await
|
|
}
|
|
|
|
/// Sends a message to the associated topic using a [PhoenixSender] also setting the ref
|
|
///
|
|
/// See [PhoenixSender::send]
|
|
pub async fn send_with_ref(
|
|
&mut self,
|
|
payload: impl Serialize,
|
|
reference: impl ToString,
|
|
) -> Result<()> {
|
|
self.phoenix_sender
|
|
.send_with_ref(&self.topic, payload, reference)
|
|
.await
|
|
}
|
|
|
|
pub fn get_sender(&mut self) -> &mut PhoenixSender {
|
|
&mut self.phoenix_sender
|
|
}
|
|
}
|
|
|
|
impl PhoenixSender {
|
|
async fn send_internal(
|
|
&mut self,
|
|
topic: impl Into<String>,
|
|
payload: impl Serialize,
|
|
reference: Option<String>,
|
|
) -> Result<()> {
|
|
// We don't care about the reply type when serializing
|
|
let str = serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, reference))?;
|
|
self.sender.send(Message::text(str)).await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Sends a message upstream to a connected [PhoenixChannel].
|
|
///
|
|
/// # Parameters
|
|
/// - topic: Phoenix topic
|
|
/// - payload: Message's payload
|
|
pub async fn send(&mut self, topic: impl Into<String>, payload: impl Serialize) -> Result<()> {
|
|
self.send_internal(topic, payload, None).await
|
|
}
|
|
|
|
/// Sends a message upstream to a connected [PhoenixChannel] using the given ref number.
|
|
///
|
|
/// # Parameters
|
|
/// - topic: Phoenix topic
|
|
/// - payload: Message's payload
|
|
/// - reference: Reference number used in the message, if the message has a response that same number will be used
|
|
pub async fn send_with_ref(
|
|
&mut self,
|
|
topic: impl Into<String>,
|
|
payload: impl Serialize,
|
|
reference: impl ToString,
|
|
) -> Result<()> {
|
|
self.send_internal(topic, payload, Some(reference.to_string()))
|
|
.await
|
|
}
|
|
|
|
/// Join a phoenix topic, meaning that after this method is invoked [PhoenixChannel] will
|
|
/// receive messages from that topic, given that upstream accepts you into the given topic.
|
|
pub async fn join_topic(&mut self, topic: impl Into<String>) -> Result<()> {
|
|
self.send(topic, EgressControlMessage::PhxJoin(Empty {}))
|
|
.await
|
|
}
|
|
|
|
/// Closes the [PhoenixChannel]
|
|
pub async fn close(&mut self) -> Result<()> {
|
|
self.sender.send(Message::Close(None)).await?;
|
|
self.sender.close().await?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use crate::control::{
|
|
ErrorInfo, KnownError, Payload, PhxReply::Error, Reason, ReplyMessage::PhxReply,
|
|
UnknownError,
|
|
};
|
|
|
|
#[test]
|
|
fn unmatched_topic_reply() {
|
|
let actual_reply = r#"
|
|
{
|
|
"event":"phx_reply",
|
|
"ref":"12",
|
|
"topic":"client",
|
|
"payload":{
|
|
"status":"error",
|
|
"response":{
|
|
"reason":"unmatched topic"
|
|
}
|
|
}
|
|
}
|
|
"#;
|
|
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
|
let expected_reply = Payload::<(), ()>::Reply(PhxReply(Error(ErrorInfo::Reason(
|
|
Reason::Known(KnownError::UnmatchedTopic),
|
|
))));
|
|
assert_eq!(actual_reply, expected_reply);
|
|
}
|
|
|
|
#[test]
|
|
fn unexpected_error_reply() {
|
|
let actual_reply = r#"
|
|
{
|
|
"event":"phx_reply",
|
|
"ref":"12",
|
|
"topic":"client",
|
|
"payload":{
|
|
"status":"error",
|
|
"response":{
|
|
"reason":"bad reply"
|
|
}
|
|
}
|
|
}
|
|
"#;
|
|
let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap();
|
|
let expected_reply = Payload::<(), ()>::Reply(PhxReply(Error(ErrorInfo::Reason(
|
|
Reason::Unknown(UnknownError("bad reply".to_string())),
|
|
))));
|
|
assert_eq!(actual_reply, expected_reply);
|
|
}
|
|
}
|