diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 6531e602b..3b92bc898 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -131,13 +131,6 @@ impl Eventloop { } match self.portal.poll(cx)? { - Poll::Ready(phoenix_channel::Event::InboundMessage { - msg: IngressMessages::Init(_), - .. - }) => { - tracing::warn!("Received init message during operation"); - debug_assert!(false, "Received init message during operation"); - } Poll::Ready(phoenix_channel::Event::InboundMessage { msg: IngressMessages::RequestConnection(req), .. diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 03df56640..63be23b5a 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -1,6 +1,6 @@ use crate::control::ControlSignaler; use crate::eventloop::{Eventloop, PHOENIX_TOPIC}; -use crate::messages::IngressMessages; +use crate::messages::InitGateway; use anyhow::{Context as _, Result}; use backoff::ExponentialBackoffBuilder; use boringtun::x25519::StaticSecret; @@ -59,38 +59,18 @@ async fn connect(private_key: StaticSecret, connect_url: Secret) -> R let signaler = ControlSignaler::new(control_tx); let tunnel = Arc::new(Tunnel::new(private_key, signaler, CallbackHandler).await?); - tracing::debug!("Attempting connection to portal..."); + let (channel, init) = phoenix_channel::init::( + connect_url, + get_user_agent(), + PHOENIX_TOPIC, + (), + ) + .await??; - let mut channel = - phoenix_channel::PhoenixChannel::connect(connect_url, get_user_agent()).await?; - channel.join(PHOENIX_TOPIC, ()); - - let channel = loop { - match future::poll_fn(|cx| channel.poll(cx)) - .await - .context("portal connection failed")? - { - phoenix_channel::Event::JoinedRoom { topic } if topic == PHOENIX_TOPIC => { - tracing::info!("Joined gateway room on portal") - } - phoenix_channel::Event::InboundMessage { - topic, - msg: IngressMessages::Init(init), - } => { - tracing::info!("Received init message from portal on topic {topic}"); - - tunnel - .set_interface(&init.interface) - .await - .context("Failed to set interface")?; - - break channel; - } - other => { - tracing::debug!("Unhandled message from portal: {other:?}"); - } - } - }; + tunnel + .set_interface(&init.interface) + .await + .context("Failed to set interface")?; let mut eventloop = Eventloop::new(tunnel, control_rx, channel); diff --git a/rust/gateway/src/messages.rs b/rust/gateway/src/messages.rs index bd48ffee4..d8ee2b06d 100644 --- a/rust/gateway/src/messages.rs +++ b/rust/gateway/src/messages.rs @@ -89,7 +89,6 @@ pub struct AllowAccess { // TODO: We will need to re-visit webrtc-rs #[allow(clippy::large_enum_variant)] pub enum IngressMessages { - Init(InitGateway), RequestConnection(RequestConnection), AllowAccess(AllowAccess), IceCandidates(ClientIceCandidates), @@ -135,6 +134,7 @@ pub struct ConnectionReady { #[cfg(test)] mod test { use connlib_shared::{control::PhoenixMessage, messages::Interface}; + use phoenix_channel::InitMessage; use super::{IngressMessages, InitGateway}; @@ -190,35 +190,18 @@ mod test { } #[test] fn init_phoenix_message() { - let m = PhoenixMessage::new( - "gateway:83d28051-324e-48fe-98ed-19690899b3b6", - IngressMessages::Init(InitGateway { - interface: Interface { - ipv4: "100.115.164.78".parse().unwrap(), - ipv6: "fd00:2021:1111::2c:f6ab".parse().unwrap(), - upstream_dns: vec![], - }, - ipv4_masquerade_enabled: true, - ipv6_masquerade_enabled: true, - }), - None, - ); - - let message = r#"{ - "event": "init", - "payload": { - "interface": { - "ipv4": "100.115.164.78", - "ipv6": "fd00:2021:1111::2c:f6ab" - }, - "ipv4_masquerade_enabled": true, - "ipv6_masquerade_enabled": true + let m = InitMessage::Init(InitGateway { + interface: Interface { + ipv4: "100.115.164.78".parse().unwrap(), + ipv6: "fd00:2021:1111::2c:f6ab".parse().unwrap(), + upstream_dns: vec![], }, - "ref": null, - "topic": "gateway:83d28051-324e-48fe-98ed-19690899b3b6" - }"#; - let ingress_message: PhoenixMessage = - serde_json::from_str(message).unwrap(); + ipv4_masquerade_enabled: true, + ipv6_masquerade_enabled: true, + }); + + let message = r#"{"event":"init","ref":null,"topic":"gateway","payload":{"interface":{"ipv6":"fd00:2021:1111::2c:f6ab","ipv4":"100.115.164.78"},"ipv4_masquerade_enabled":true,"ipv6_masquerade_enabled":true}}"#; + let ingress_message = serde_json::from_str::>(message).unwrap(); assert_eq!(m, ingress_message); } } diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 51f981ad0..fe4db33db 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use std::{fmt, marker::PhantomData, time::Duration}; +use std::{fmt, future, marker::PhantomData, time::Duration}; use base64::Engine; use futures::{FutureExt, SinkExt, StreamExt}; @@ -33,6 +33,63 @@ pub struct PhoenixChannel { pending_join_requests: HashSet, } +/// Creates a new [PhoenixChannel] to the given endpoint and waits for an `init` message. +/// +/// The provided URL must contain a host. +/// Additionally, you must already provide any query parameters required for authentication. +#[tracing::instrument(level = "debug", skip(payload, secret_url))] +#[allow(clippy::type_complexity)] +pub async fn init( + secret_url: Secret, + user_agent: String, + login_topic: &'static str, + payload: impl Serialize, +) -> Result< + Result<(PhoenixChannel, TInitM), UnexpectedEventDuringInit>, + Error, +> +where + TInitM: DeserializeOwned + fmt::Debug, + TInboundMsg: DeserializeOwned, + TOutboundRes: DeserializeOwned, +{ + let mut channel = + PhoenixChannel::, ()>::connect(secret_url, user_agent).await?; + channel.join(login_topic, payload); + + tracing::info!("Connected to portal, waiting for `init` message"); + + let (channel, init_message) = loop { + match future::poll_fn(|cx| channel.poll(cx)).await? { + Event::JoinedRoom { topic } if topic == login_topic => { + tracing::info!("Joined {login_topic} room on portal") + } + Event::InboundMessage { + topic, + msg: InitMessage::Init(msg), + } if topic == login_topic => { + tracing::info!("Received init message from portal"); + + break (channel, msg); + } + Event::HeartbeatSent => {} + e => return Ok(Err(UnexpectedEventDuringInit(format!("{e:?}")))), + } + }; + + Ok(Ok((channel.cast(), init_message))) +} + +#[derive(serde::Deserialize, Debug, PartialEq)] +#[serde(rename_all = "snake_case", tag = "event", content = "payload")] +pub enum InitMessage { + Init(M), +} + +#[derive(Debug, thiserror::Error)] +#[error("encountered unexpected event during init: {0}")] +pub struct UnexpectedEventDuringInit(String); + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("provided URI is missing a host")] @@ -263,6 +320,20 @@ where next_id } + + /// Cast this instance of [PhoenixChannel] to new message types. + fn cast( + self, + ) -> PhoenixChannel { + PhoenixChannel { + stream: self.stream, + pending_messages: self.pending_messages, + next_request_id: self.next_request_id, + next_heartbeat: self.next_heartbeat, + _phantom: PhantomData, + pending_join_requests: self.pending_join_requests, + } + } } #[derive(Debug)] @@ -405,7 +476,6 @@ mod tests { #[serde(rename_all = "snake_case", tag = "event", content = "payload")] // This line makes it all work. enum Msg { Shout { hello: String }, - Init {}, } #[test] @@ -433,12 +503,18 @@ mod tests { } #[test] fn can_deserialize_init_message() { + #[derive(Deserialize, PartialEq, Debug)] + struct EmptyInit {} + let msg = r#"{"event":"init","payload":{},"ref":null,"topic":"relay"}"#; - let msg = serde_json::from_str::>(msg).unwrap(); + let msg = serde_json::from_str::, ()>>(msg).unwrap(); assert_eq!(msg.topic, "relay"); assert_eq!(msg.reference, None); - assert_eq!(msg.payload, Payload::Message(Msg::Init {})); + assert_eq!( + msg.payload, + Payload::Message(InitMessage::Init(EmptyInit {})) + ); } } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 898d9eb43..ea8055131 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -241,7 +241,7 @@ async fn connect_to_portal( token: &SecretString, mut url: Url, stamp_secret: &SecretString, -) -> Result>> { +) -> Result>> { use secrecy::ExposeSecret; if !url.path().is_empty() { @@ -261,57 +261,27 @@ async fn connect_to_portal( .append_pair("ipv6", &public_ip6_addr.to_string()); } - let mut channel = PhoenixChannel::::connect( + let (channel, Init {}) = phoenix_channel::init::( Secret::from(SecureUrl::from_url(url)), format!("relay/{}", env!("CARGO_PKG_VERSION")), - ) - .await - .context("Failed to connect to the portal")?; - - tracing::info!("Connected to portal, waiting for init message",); - - channel.join( "relay", JoinMessage { stamp_secret: stamp_secret.expose_secret().to_string(), }, - ); + ) + .await??; - loop { - match future::poll_fn(|cx| channel.poll(cx)) - .await - .context("portal connection failed")? - { - Event::JoinedRoom { topic } if topic == "relay" => { - tracing::info!("Joined relay room on portal") - } - Event::InboundMessage { - topic, - msg: InboundPortalMessage::Init {}, - } => { - tracing::info!( - "Received init message from portal on topic {topic}, starting relay activities" - ); - return Ok(Some(channel)); - } - other => { - tracing::debug!("Unhandled message from portal: {other:?}"); - } - } - } + Ok(Some(channel)) } +#[derive(serde::Deserialize, Debug)] +struct Init {} + #[derive(serde::Serialize, PartialEq, Debug)] struct JoinMessage { stamp_secret: String, } -#[derive(serde::Deserialize, PartialEq, Debug)] -#[serde(rename_all = "snake_case", tag = "event", content = "payload")] -enum InboundPortalMessage { - Init {}, -} - #[cfg(debug_assertions)] fn make_rng(seed: Option) -> StdRng { let Some(seed) = seed else { @@ -337,7 +307,7 @@ struct Eventloop { outbound_ip4_data_sender: mpsc::Sender<(Vec, SocketAddr)>, outbound_ip6_data_sender: mpsc::Sender<(Vec, SocketAddr)>, server: Server, - channel: Option>, + channel: Option>, allocations: HashMap<(AllocationId, AddressFamily), Allocation>, relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, relay_data_receiver: mpsc::Receiver<(Vec, SocketAddr, AllocationId)>, @@ -350,7 +320,7 @@ where { fn new( server: Server, - channel: Option>, + channel: Option>, public_address: IpStack, ) -> Result { let (relay_data_sender, relay_data_receiver) = mpsc::channel(1); @@ -506,13 +476,6 @@ where // Priority 5: Handle portal messages match self.channel.as_mut().map(|c| c.poll(cx)) { - Some(Poll::Ready(Ok(Event::InboundMessage { - msg: InboundPortalMessage::Init {}, - .. - }))) => { - tracing::warn!("Received init message during operation"); - continue; - } Some(Poll::Ready(Err(Error::Serde(e)))) => { tracing::warn!("Failed to deserialize portal message: {e}"); continue; // This is not a hard-error, we can continue. @@ -535,17 +498,15 @@ where tracing::warn!("Request with ID {req_id} on topic {topic} failed: {reason}"); continue; } - Some(Poll::Ready(Ok(Event::InboundReq { - req: InboundPortalMessage::Init {}, - .. - }))) => { - return Poll::Ready(Err(anyhow!("Init message is not a request"))); - } Some(Poll::Ready(Ok(Event::HeartbeatSent))) => { tracing::debug!("Heartbeat sent to portal"); continue; } - Some(Poll::Pending) | None => {} + Some(Poll::Ready(Ok( + Event::InboundMessage { msg: (), .. } | Event::InboundReq { req: (), .. }, + ))) + | Some(Poll::Pending) + | None => {} } return Poll::Pending;