refactor(phoenix-channel): introduce init function (#2260)

What is common across all our usages of the phoenix channel is that we
wait for some kind of `init` message before we fully start-up. We
extract this pattern into a dedicated function within the
`phoenix-channel` crate.

---------

Signed-off-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
Thomas Eizinger
2023-10-10 07:45:46 +11:00
committed by GitHub
parent 30a681ad6b
commit 72d6942a71
5 changed files with 119 additions and 126 deletions

View File

@@ -131,13 +131,6 @@ impl Eventloop {
}
match self.portal.poll(cx)? {
Poll::Ready(phoenix_channel::Event::InboundMessage {
msg: IngressMessages::Init(_),
..
}) => {
tracing::warn!("Received init message during operation");
debug_assert!(false, "Received init message during operation");
}
Poll::Ready(phoenix_channel::Event::InboundMessage {
msg: IngressMessages::RequestConnection(req),
..

View File

@@ -1,6 +1,6 @@
use crate::control::ControlSignaler;
use crate::eventloop::{Eventloop, PHOENIX_TOPIC};
use crate::messages::IngressMessages;
use crate::messages::InitGateway;
use anyhow::{Context as _, Result};
use backoff::ExponentialBackoffBuilder;
use boringtun::x25519::StaticSecret;
@@ -59,38 +59,18 @@ async fn connect(private_key: StaticSecret, connect_url: Secret<SecureUrl>) -> R
let signaler = ControlSignaler::new(control_tx);
let tunnel = Arc::new(Tunnel::new(private_key, signaler, CallbackHandler).await?);
tracing::debug!("Attempting connection to portal...");
let (channel, init) = phoenix_channel::init::<InitGateway, _, _>(
connect_url,
get_user_agent(),
PHOENIX_TOPIC,
(),
)
.await??;
let mut channel =
phoenix_channel::PhoenixChannel::connect(connect_url, get_user_agent()).await?;
channel.join(PHOENIX_TOPIC, ());
let channel = loop {
match future::poll_fn(|cx| channel.poll(cx))
.await
.context("portal connection failed")?
{
phoenix_channel::Event::JoinedRoom { topic } if topic == PHOENIX_TOPIC => {
tracing::info!("Joined gateway room on portal")
}
phoenix_channel::Event::InboundMessage {
topic,
msg: IngressMessages::Init(init),
} => {
tracing::info!("Received init message from portal on topic {topic}");
tunnel
.set_interface(&init.interface)
.await
.context("Failed to set interface")?;
break channel;
}
other => {
tracing::debug!("Unhandled message from portal: {other:?}");
}
}
};
tunnel
.set_interface(&init.interface)
.await
.context("Failed to set interface")?;
let mut eventloop = Eventloop::new(tunnel, control_rx, channel);

View File

@@ -89,7 +89,6 @@ pub struct AllowAccess {
// TODO: We will need to re-visit webrtc-rs
#[allow(clippy::large_enum_variant)]
pub enum IngressMessages {
Init(InitGateway),
RequestConnection(RequestConnection),
AllowAccess(AllowAccess),
IceCandidates(ClientIceCandidates),
@@ -135,6 +134,7 @@ pub struct ConnectionReady {
#[cfg(test)]
mod test {
use connlib_shared::{control::PhoenixMessage, messages::Interface};
use phoenix_channel::InitMessage;
use super::{IngressMessages, InitGateway};
@@ -190,35 +190,18 @@ mod test {
}
#[test]
fn init_phoenix_message() {
let m = PhoenixMessage::new(
"gateway:83d28051-324e-48fe-98ed-19690899b3b6",
IngressMessages::Init(InitGateway {
interface: Interface {
ipv4: "100.115.164.78".parse().unwrap(),
ipv6: "fd00:2021:1111::2c:f6ab".parse().unwrap(),
upstream_dns: vec![],
},
ipv4_masquerade_enabled: true,
ipv6_masquerade_enabled: true,
}),
None,
);
let message = r#"{
"event": "init",
"payload": {
"interface": {
"ipv4": "100.115.164.78",
"ipv6": "fd00:2021:1111::2c:f6ab"
},
"ipv4_masquerade_enabled": true,
"ipv6_masquerade_enabled": true
let m = InitMessage::Init(InitGateway {
interface: Interface {
ipv4: "100.115.164.78".parse().unwrap(),
ipv6: "fd00:2021:1111::2c:f6ab".parse().unwrap(),
upstream_dns: vec![],
},
"ref": null,
"topic": "gateway:83d28051-324e-48fe-98ed-19690899b3b6"
}"#;
let ingress_message: PhoenixMessage<IngressMessages, ()> =
serde_json::from_str(message).unwrap();
ipv4_masquerade_enabled: true,
ipv6_masquerade_enabled: true,
});
let message = r#"{"event":"init","ref":null,"topic":"gateway","payload":{"interface":{"ipv6":"fd00:2021:1111::2c:f6ab","ipv4":"100.115.164.78"},"ipv4_masquerade_enabled":true,"ipv6_masquerade_enabled":true}}"#;
let ingress_message = serde_json::from_str::<InitMessage<InitGateway>>(message).unwrap();
assert_eq!(m, ingress_message);
}
}

View File

@@ -1,5 +1,5 @@
use std::collections::HashSet;
use std::{fmt, marker::PhantomData, time::Duration};
use std::{fmt, future, marker::PhantomData, time::Duration};
use base64::Engine;
use futures::{FutureExt, SinkExt, StreamExt};
@@ -33,6 +33,63 @@ pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
pending_join_requests: HashSet<OutboundRequestId>,
}
/// Creates a new [PhoenixChannel] to the given endpoint and waits for an `init` message.
///
/// The provided URL must contain a host.
/// Additionally, you must already provide any query parameters required for authentication.
#[tracing::instrument(level = "debug", skip(payload, secret_url))]
#[allow(clippy::type_complexity)]
pub async fn init<TInitM, TInboundMsg, TOutboundRes>(
secret_url: Secret<SecureUrl>,
user_agent: String,
login_topic: &'static str,
payload: impl Serialize,
) -> Result<
Result<(PhoenixChannel<TInboundMsg, TOutboundRes>, TInitM), UnexpectedEventDuringInit>,
Error,
>
where
TInitM: DeserializeOwned + fmt::Debug,
TInboundMsg: DeserializeOwned,
TOutboundRes: DeserializeOwned,
{
let mut channel =
PhoenixChannel::<InitMessage<TInitM>, ()>::connect(secret_url, user_agent).await?;
channel.join(login_topic, payload);
tracing::info!("Connected to portal, waiting for `init` message");
let (channel, init_message) = loop {
match future::poll_fn(|cx| channel.poll(cx)).await? {
Event::JoinedRoom { topic } if topic == login_topic => {
tracing::info!("Joined {login_topic} room on portal")
}
Event::InboundMessage {
topic,
msg: InitMessage::Init(msg),
} if topic == login_topic => {
tracing::info!("Received init message from portal");
break (channel, msg);
}
Event::HeartbeatSent => {}
e => return Ok(Err(UnexpectedEventDuringInit(format!("{e:?}")))),
}
};
Ok(Ok((channel.cast(), init_message)))
}
#[derive(serde::Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
pub enum InitMessage<M> {
Init(M),
}
#[derive(Debug, thiserror::Error)]
#[error("encountered unexpected event during init: {0}")]
pub struct UnexpectedEventDuringInit(String);
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("provided URI is missing a host")]
@@ -263,6 +320,20 @@ where
next_id
}
/// Cast this instance of [PhoenixChannel] to new message types.
fn cast<TInboundMsgNew, TOutboundResNew>(
self,
) -> PhoenixChannel<TInboundMsgNew, TOutboundResNew> {
PhoenixChannel {
stream: self.stream,
pending_messages: self.pending_messages,
next_request_id: self.next_request_id,
next_heartbeat: self.next_heartbeat,
_phantom: PhantomData,
pending_join_requests: self.pending_join_requests,
}
}
}
#[derive(Debug)]
@@ -405,7 +476,6 @@ mod tests {
#[serde(rename_all = "snake_case", tag = "event", content = "payload")] // This line makes it all work.
enum Msg {
Shout { hello: String },
Init {},
}
#[test]
@@ -433,12 +503,18 @@ mod tests {
}
#[test]
fn can_deserialize_init_message() {
#[derive(Deserialize, PartialEq, Debug)]
struct EmptyInit {}
let msg = r#"{"event":"init","payload":{},"ref":null,"topic":"relay"}"#;
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
let msg = serde_json::from_str::<PhoenixMessage<InitMessage<EmptyInit>, ()>>(msg).unwrap();
assert_eq!(msg.topic, "relay");
assert_eq!(msg.reference, None);
assert_eq!(msg.payload, Payload::Message(Msg::Init {}));
assert_eq!(
msg.payload,
Payload::Message(InitMessage::Init(EmptyInit {}))
);
}
}

View File

@@ -241,7 +241,7 @@ async fn connect_to_portal(
token: &SecretString,
mut url: Url,
stamp_secret: &SecretString,
) -> Result<Option<PhoenixChannel<InboundPortalMessage, ()>>> {
) -> Result<Option<PhoenixChannel<(), ()>>> {
use secrecy::ExposeSecret;
if !url.path().is_empty() {
@@ -261,57 +261,27 @@ async fn connect_to_portal(
.append_pair("ipv6", &public_ip6_addr.to_string());
}
let mut channel = PhoenixChannel::<InboundPortalMessage, ()>::connect(
let (channel, Init {}) = phoenix_channel::init::<Init, _, _>(
Secret::from(SecureUrl::from_url(url)),
format!("relay/{}", env!("CARGO_PKG_VERSION")),
)
.await
.context("Failed to connect to the portal")?;
tracing::info!("Connected to portal, waiting for init message",);
channel.join(
"relay",
JoinMessage {
stamp_secret: stamp_secret.expose_secret().to_string(),
},
);
)
.await??;
loop {
match future::poll_fn(|cx| channel.poll(cx))
.await
.context("portal connection failed")?
{
Event::JoinedRoom { topic } if topic == "relay" => {
tracing::info!("Joined relay room on portal")
}
Event::InboundMessage {
topic,
msg: InboundPortalMessage::Init {},
} => {
tracing::info!(
"Received init message from portal on topic {topic}, starting relay activities"
);
return Ok(Some(channel));
}
other => {
tracing::debug!("Unhandled message from portal: {other:?}");
}
}
}
Ok(Some(channel))
}
#[derive(serde::Deserialize, Debug)]
struct Init {}
#[derive(serde::Serialize, PartialEq, Debug)]
struct JoinMessage {
stamp_secret: String,
}
#[derive(serde::Deserialize, PartialEq, Debug)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum InboundPortalMessage {
Init {},
}
#[cfg(debug_assertions)]
fn make_rng(seed: Option<u64>) -> StdRng {
let Some(seed) = seed else {
@@ -337,7 +307,7 @@ struct Eventloop<R> {
outbound_ip4_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
outbound_ip6_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
server: Server<R>,
channel: Option<PhoenixChannel<InboundPortalMessage, ()>>,
channel: Option<PhoenixChannel<(), ()>>,
allocations: HashMap<(AllocationId, AddressFamily), Allocation>,
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
relay_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr, AllocationId)>,
@@ -350,7 +320,7 @@ where
{
fn new(
server: Server<R>,
channel: Option<PhoenixChannel<InboundPortalMessage, ()>>,
channel: Option<PhoenixChannel<(), ()>>,
public_address: IpStack,
) -> Result<Self> {
let (relay_data_sender, relay_data_receiver) = mpsc::channel(1);
@@ -506,13 +476,6 @@ where
// Priority 5: Handle portal messages
match self.channel.as_mut().map(|c| c.poll(cx)) {
Some(Poll::Ready(Ok(Event::InboundMessage {
msg: InboundPortalMessage::Init {},
..
}))) => {
tracing::warn!("Received init message during operation");
continue;
}
Some(Poll::Ready(Err(Error::Serde(e)))) => {
tracing::warn!("Failed to deserialize portal message: {e}");
continue; // This is not a hard-error, we can continue.
@@ -535,17 +498,15 @@ where
tracing::warn!("Request with ID {req_id} on topic {topic} failed: {reason}");
continue;
}
Some(Poll::Ready(Ok(Event::InboundReq {
req: InboundPortalMessage::Init {},
..
}))) => {
return Poll::Ready(Err(anyhow!("Init message is not a request")));
}
Some(Poll::Ready(Ok(Event::HeartbeatSent))) => {
tracing::debug!("Heartbeat sent to portal");
continue;
}
Some(Poll::Pending) | None => {}
Some(Poll::Ready(Ok(
Event::InboundMessage { msg: (), .. } | Event::InboundReq { req: (), .. },
)))
| Some(Poll::Pending)
| None => {}
}
return Poll::Pending;