mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(phoenix-channel): introduce init function (#2260)
What is common across all our usages of the phoenix channel is that we wait for some kind of `init` message before we fully start-up. We extract this pattern into a dedicated function within the `phoenix-channel` crate. --------- Signed-off-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
@@ -131,13 +131,6 @@ impl Eventloop {
|
||||
}
|
||||
|
||||
match self.portal.poll(cx)? {
|
||||
Poll::Ready(phoenix_channel::Event::InboundMessage {
|
||||
msg: IngressMessages::Init(_),
|
||||
..
|
||||
}) => {
|
||||
tracing::warn!("Received init message during operation");
|
||||
debug_assert!(false, "Received init message during operation");
|
||||
}
|
||||
Poll::Ready(phoenix_channel::Event::InboundMessage {
|
||||
msg: IngressMessages::RequestConnection(req),
|
||||
..
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::control::ControlSignaler;
|
||||
use crate::eventloop::{Eventloop, PHOENIX_TOPIC};
|
||||
use crate::messages::IngressMessages;
|
||||
use crate::messages::InitGateway;
|
||||
use anyhow::{Context as _, Result};
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use boringtun::x25519::StaticSecret;
|
||||
@@ -59,38 +59,18 @@ async fn connect(private_key: StaticSecret, connect_url: Secret<SecureUrl>) -> R
|
||||
let signaler = ControlSignaler::new(control_tx);
|
||||
let tunnel = Arc::new(Tunnel::new(private_key, signaler, CallbackHandler).await?);
|
||||
|
||||
tracing::debug!("Attempting connection to portal...");
|
||||
let (channel, init) = phoenix_channel::init::<InitGateway, _, _>(
|
||||
connect_url,
|
||||
get_user_agent(),
|
||||
PHOENIX_TOPIC,
|
||||
(),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let mut channel =
|
||||
phoenix_channel::PhoenixChannel::connect(connect_url, get_user_agent()).await?;
|
||||
channel.join(PHOENIX_TOPIC, ());
|
||||
|
||||
let channel = loop {
|
||||
match future::poll_fn(|cx| channel.poll(cx))
|
||||
.await
|
||||
.context("portal connection failed")?
|
||||
{
|
||||
phoenix_channel::Event::JoinedRoom { topic } if topic == PHOENIX_TOPIC => {
|
||||
tracing::info!("Joined gateway room on portal")
|
||||
}
|
||||
phoenix_channel::Event::InboundMessage {
|
||||
topic,
|
||||
msg: IngressMessages::Init(init),
|
||||
} => {
|
||||
tracing::info!("Received init message from portal on topic {topic}");
|
||||
|
||||
tunnel
|
||||
.set_interface(&init.interface)
|
||||
.await
|
||||
.context("Failed to set interface")?;
|
||||
|
||||
break channel;
|
||||
}
|
||||
other => {
|
||||
tracing::debug!("Unhandled message from portal: {other:?}");
|
||||
}
|
||||
}
|
||||
};
|
||||
tunnel
|
||||
.set_interface(&init.interface)
|
||||
.await
|
||||
.context("Failed to set interface")?;
|
||||
|
||||
let mut eventloop = Eventloop::new(tunnel, control_rx, channel);
|
||||
|
||||
|
||||
@@ -89,7 +89,6 @@ pub struct AllowAccess {
|
||||
// TODO: We will need to re-visit webrtc-rs
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum IngressMessages {
|
||||
Init(InitGateway),
|
||||
RequestConnection(RequestConnection),
|
||||
AllowAccess(AllowAccess),
|
||||
IceCandidates(ClientIceCandidates),
|
||||
@@ -135,6 +134,7 @@ pub struct ConnectionReady {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use connlib_shared::{control::PhoenixMessage, messages::Interface};
|
||||
use phoenix_channel::InitMessage;
|
||||
|
||||
use super::{IngressMessages, InitGateway};
|
||||
|
||||
@@ -190,35 +190,18 @@ mod test {
|
||||
}
|
||||
#[test]
|
||||
fn init_phoenix_message() {
|
||||
let m = PhoenixMessage::new(
|
||||
"gateway:83d28051-324e-48fe-98ed-19690899b3b6",
|
||||
IngressMessages::Init(InitGateway {
|
||||
interface: Interface {
|
||||
ipv4: "100.115.164.78".parse().unwrap(),
|
||||
ipv6: "fd00:2021:1111::2c:f6ab".parse().unwrap(),
|
||||
upstream_dns: vec![],
|
||||
},
|
||||
ipv4_masquerade_enabled: true,
|
||||
ipv6_masquerade_enabled: true,
|
||||
}),
|
||||
None,
|
||||
);
|
||||
|
||||
let message = r#"{
|
||||
"event": "init",
|
||||
"payload": {
|
||||
"interface": {
|
||||
"ipv4": "100.115.164.78",
|
||||
"ipv6": "fd00:2021:1111::2c:f6ab"
|
||||
},
|
||||
"ipv4_masquerade_enabled": true,
|
||||
"ipv6_masquerade_enabled": true
|
||||
let m = InitMessage::Init(InitGateway {
|
||||
interface: Interface {
|
||||
ipv4: "100.115.164.78".parse().unwrap(),
|
||||
ipv6: "fd00:2021:1111::2c:f6ab".parse().unwrap(),
|
||||
upstream_dns: vec![],
|
||||
},
|
||||
"ref": null,
|
||||
"topic": "gateway:83d28051-324e-48fe-98ed-19690899b3b6"
|
||||
}"#;
|
||||
let ingress_message: PhoenixMessage<IngressMessages, ()> =
|
||||
serde_json::from_str(message).unwrap();
|
||||
ipv4_masquerade_enabled: true,
|
||||
ipv6_masquerade_enabled: true,
|
||||
});
|
||||
|
||||
let message = r#"{"event":"init","ref":null,"topic":"gateway","payload":{"interface":{"ipv6":"fd00:2021:1111::2c:f6ab","ipv4":"100.115.164.78"},"ipv4_masquerade_enabled":true,"ipv6_masquerade_enabled":true}}"#;
|
||||
let ingress_message = serde_json::from_str::<InitMessage<InitGateway>>(message).unwrap();
|
||||
assert_eq!(m, ingress_message);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::collections::HashSet;
|
||||
use std::{fmt, marker::PhantomData, time::Duration};
|
||||
use std::{fmt, future, marker::PhantomData, time::Duration};
|
||||
|
||||
use base64::Engine;
|
||||
use futures::{FutureExt, SinkExt, StreamExt};
|
||||
@@ -33,6 +33,63 @@ pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
|
||||
pending_join_requests: HashSet<OutboundRequestId>,
|
||||
}
|
||||
|
||||
/// Creates a new [PhoenixChannel] to the given endpoint and waits for an `init` message.
|
||||
///
|
||||
/// The provided URL must contain a host.
|
||||
/// Additionally, you must already provide any query parameters required for authentication.
|
||||
#[tracing::instrument(level = "debug", skip(payload, secret_url))]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub async fn init<TInitM, TInboundMsg, TOutboundRes>(
|
||||
secret_url: Secret<SecureUrl>,
|
||||
user_agent: String,
|
||||
login_topic: &'static str,
|
||||
payload: impl Serialize,
|
||||
) -> Result<
|
||||
Result<(PhoenixChannel<TInboundMsg, TOutboundRes>, TInitM), UnexpectedEventDuringInit>,
|
||||
Error,
|
||||
>
|
||||
where
|
||||
TInitM: DeserializeOwned + fmt::Debug,
|
||||
TInboundMsg: DeserializeOwned,
|
||||
TOutboundRes: DeserializeOwned,
|
||||
{
|
||||
let mut channel =
|
||||
PhoenixChannel::<InitMessage<TInitM>, ()>::connect(secret_url, user_agent).await?;
|
||||
channel.join(login_topic, payload);
|
||||
|
||||
tracing::info!("Connected to portal, waiting for `init` message");
|
||||
|
||||
let (channel, init_message) = loop {
|
||||
match future::poll_fn(|cx| channel.poll(cx)).await? {
|
||||
Event::JoinedRoom { topic } if topic == login_topic => {
|
||||
tracing::info!("Joined {login_topic} room on portal")
|
||||
}
|
||||
Event::InboundMessage {
|
||||
topic,
|
||||
msg: InitMessage::Init(msg),
|
||||
} if topic == login_topic => {
|
||||
tracing::info!("Received init message from portal");
|
||||
|
||||
break (channel, msg);
|
||||
}
|
||||
Event::HeartbeatSent => {}
|
||||
e => return Ok(Err(UnexpectedEventDuringInit(format!("{e:?}")))),
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Ok((channel.cast(), init_message)))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Debug, PartialEq)]
|
||||
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
||||
pub enum InitMessage<M> {
|
||||
Init(M),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("encountered unexpected event during init: {0}")]
|
||||
pub struct UnexpectedEventDuringInit(String);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("provided URI is missing a host")]
|
||||
@@ -263,6 +320,20 @@ where
|
||||
|
||||
next_id
|
||||
}
|
||||
|
||||
/// Cast this instance of [PhoenixChannel] to new message types.
|
||||
fn cast<TInboundMsgNew, TOutboundResNew>(
|
||||
self,
|
||||
) -> PhoenixChannel<TInboundMsgNew, TOutboundResNew> {
|
||||
PhoenixChannel {
|
||||
stream: self.stream,
|
||||
pending_messages: self.pending_messages,
|
||||
next_request_id: self.next_request_id,
|
||||
next_heartbeat: self.next_heartbeat,
|
||||
_phantom: PhantomData,
|
||||
pending_join_requests: self.pending_join_requests,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -405,7 +476,6 @@ mod tests {
|
||||
#[serde(rename_all = "snake_case", tag = "event", content = "payload")] // This line makes it all work.
|
||||
enum Msg {
|
||||
Shout { hello: String },
|
||||
Init {},
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -433,12 +503,18 @@ mod tests {
|
||||
}
|
||||
#[test]
|
||||
fn can_deserialize_init_message() {
|
||||
#[derive(Deserialize, PartialEq, Debug)]
|
||||
struct EmptyInit {}
|
||||
|
||||
let msg = r#"{"event":"init","payload":{},"ref":null,"topic":"relay"}"#;
|
||||
|
||||
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
|
||||
let msg = serde_json::from_str::<PhoenixMessage<InitMessage<EmptyInit>, ()>>(msg).unwrap();
|
||||
|
||||
assert_eq!(msg.topic, "relay");
|
||||
assert_eq!(msg.reference, None);
|
||||
assert_eq!(msg.payload, Payload::Message(Msg::Init {}));
|
||||
assert_eq!(
|
||||
msg.payload,
|
||||
Payload::Message(InitMessage::Init(EmptyInit {}))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,7 +241,7 @@ async fn connect_to_portal(
|
||||
token: &SecretString,
|
||||
mut url: Url,
|
||||
stamp_secret: &SecretString,
|
||||
) -> Result<Option<PhoenixChannel<InboundPortalMessage, ()>>> {
|
||||
) -> Result<Option<PhoenixChannel<(), ()>>> {
|
||||
use secrecy::ExposeSecret;
|
||||
|
||||
if !url.path().is_empty() {
|
||||
@@ -261,57 +261,27 @@ async fn connect_to_portal(
|
||||
.append_pair("ipv6", &public_ip6_addr.to_string());
|
||||
}
|
||||
|
||||
let mut channel = PhoenixChannel::<InboundPortalMessage, ()>::connect(
|
||||
let (channel, Init {}) = phoenix_channel::init::<Init, _, _>(
|
||||
Secret::from(SecureUrl::from_url(url)),
|
||||
format!("relay/{}", env!("CARGO_PKG_VERSION")),
|
||||
)
|
||||
.await
|
||||
.context("Failed to connect to the portal")?;
|
||||
|
||||
tracing::info!("Connected to portal, waiting for init message",);
|
||||
|
||||
channel.join(
|
||||
"relay",
|
||||
JoinMessage {
|
||||
stamp_secret: stamp_secret.expose_secret().to_string(),
|
||||
},
|
||||
);
|
||||
)
|
||||
.await??;
|
||||
|
||||
loop {
|
||||
match future::poll_fn(|cx| channel.poll(cx))
|
||||
.await
|
||||
.context("portal connection failed")?
|
||||
{
|
||||
Event::JoinedRoom { topic } if topic == "relay" => {
|
||||
tracing::info!("Joined relay room on portal")
|
||||
}
|
||||
Event::InboundMessage {
|
||||
topic,
|
||||
msg: InboundPortalMessage::Init {},
|
||||
} => {
|
||||
tracing::info!(
|
||||
"Received init message from portal on topic {topic}, starting relay activities"
|
||||
);
|
||||
return Ok(Some(channel));
|
||||
}
|
||||
other => {
|
||||
tracing::debug!("Unhandled message from portal: {other:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Some(channel))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
struct Init {}
|
||||
|
||||
#[derive(serde::Serialize, PartialEq, Debug)]
|
||||
struct JoinMessage {
|
||||
stamp_secret: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, PartialEq, Debug)]
|
||||
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
||||
enum InboundPortalMessage {
|
||||
Init {},
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn make_rng(seed: Option<u64>) -> StdRng {
|
||||
let Some(seed) = seed else {
|
||||
@@ -337,7 +307,7 @@ struct Eventloop<R> {
|
||||
outbound_ip4_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
|
||||
outbound_ip6_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
|
||||
server: Server<R>,
|
||||
channel: Option<PhoenixChannel<InboundPortalMessage, ()>>,
|
||||
channel: Option<PhoenixChannel<(), ()>>,
|
||||
allocations: HashMap<(AllocationId, AddressFamily), Allocation>,
|
||||
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
relay_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
@@ -350,7 +320,7 @@ where
|
||||
{
|
||||
fn new(
|
||||
server: Server<R>,
|
||||
channel: Option<PhoenixChannel<InboundPortalMessage, ()>>,
|
||||
channel: Option<PhoenixChannel<(), ()>>,
|
||||
public_address: IpStack,
|
||||
) -> Result<Self> {
|
||||
let (relay_data_sender, relay_data_receiver) = mpsc::channel(1);
|
||||
@@ -506,13 +476,6 @@ where
|
||||
|
||||
// Priority 5: Handle portal messages
|
||||
match self.channel.as_mut().map(|c| c.poll(cx)) {
|
||||
Some(Poll::Ready(Ok(Event::InboundMessage {
|
||||
msg: InboundPortalMessage::Init {},
|
||||
..
|
||||
}))) => {
|
||||
tracing::warn!("Received init message during operation");
|
||||
continue;
|
||||
}
|
||||
Some(Poll::Ready(Err(Error::Serde(e)))) => {
|
||||
tracing::warn!("Failed to deserialize portal message: {e}");
|
||||
continue; // This is not a hard-error, we can continue.
|
||||
@@ -535,17 +498,15 @@ where
|
||||
tracing::warn!("Request with ID {req_id} on topic {topic} failed: {reason}");
|
||||
continue;
|
||||
}
|
||||
Some(Poll::Ready(Ok(Event::InboundReq {
|
||||
req: InboundPortalMessage::Init {},
|
||||
..
|
||||
}))) => {
|
||||
return Poll::Ready(Err(anyhow!("Init message is not a request")));
|
||||
}
|
||||
Some(Poll::Ready(Ok(Event::HeartbeatSent))) => {
|
||||
tracing::debug!("Heartbeat sent to portal");
|
||||
continue;
|
||||
}
|
||||
Some(Poll::Pending) | None => {}
|
||||
Some(Poll::Ready(Ok(
|
||||
Event::InboundMessage { msg: (), .. } | Event::InboundReq { req: (), .. },
|
||||
)))
|
||||
| Some(Poll::Pending)
|
||||
| None => {}
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
|
||||
Reference in New Issue
Block a user