diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 7177619fa..8fc9b409e 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -287,9 +287,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.0" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "base64ct" @@ -514,6 +514,22 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" + [[package]] name = "cpufeatures" version = "0.2.7" @@ -878,9 +894,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" +checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" dependencies = [ "percent-encoding", ] @@ -1107,6 +1123,23 @@ dependencies = [ "sha1 0.2.0", ] +[[package]] +name = "http" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + [[package]] name = "humantime" version = "2.1.0" @@ -1121,9 +1154,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -1419,6 +1452,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "overload" version = "0.1.1" @@ -1500,9 +1539,25 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" + +[[package]] +name = "phoenix-channel" +version = "0.1.0" +dependencies = [ + "base64 0.21.2", + "futures", + "rand_core 0.6.4", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-tungstenite", + "tracing", + "url", +] [[package]] name = "pin-project-lite" @@ -1741,7 +1796,7 @@ name = "relay" version = "0.1.0" dependencies = [ "anyhow", - "base64 0.21.0", + "base64 0.21.2", "bytecodec", "bytes", "clap", @@ -1752,15 +1807,18 @@ dependencies = [ "hex", "hex-literal", "once_cell", + "phoenix-channel", "proptest", "rand", "redis", + "serde", "sha2", "stun_codec", "test-strategy", "tokio", "tracing", "tracing-subscriber", + "url", "uuid", "webrtc", ] @@ -1857,8 +1915,51 @@ dependencies = [ "base64 0.13.1", "log", "ring", - "sct", - "webpki", + "sct 0.6.1", + "webpki 0.21.4", +] + +[[package]] +name = "rustls" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e32ca28af694bc1bbf399c33a516dbdf1c90090b8ab23c2bc24f834aa2247f5f" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct 0.7.0", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +dependencies = [ + "base64 0.21.2", +] + +[[package]] +name = "rustls-webpki" +version = "0.100.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b" +dependencies = [ + "ring", + "untrusted", ] [[package]] @@ -1879,6 +1980,15 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +[[package]] +name = "schannel" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +dependencies = [ + "windows-sys 0.42.0", +] + [[package]] name = "scopeguard" version = "1.1.0" @@ -1895,6 +2005,16 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "sdp" version = "0.5.3" @@ -1921,6 +2041,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "security-framework" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.17" @@ -2287,9 +2430,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.0" +version = "1.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3c786bf8134e5a3a166db9b29ab8f48134739014a3eca7bc6bfa95d673b136f" +checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" dependencies = [ "autocfg", "bytes", @@ -2315,6 +2458,31 @@ dependencies = [ "syn 2.0.15", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.2", + "tokio", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c" +dependencies = [ + "futures-util", + "log", + "rustls 0.21.2", + "rustls-native-certs", + "tokio", + "tokio-rustls", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.8" @@ -2421,6 +2589,27 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "tungstenite" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "rustls 0.21.2", + "sha1 0.10.5", + "thiserror", + "url", + "utf-8", + "webpki 0.22.0", +] + [[package]] name = "turn" version = "0.6.1" @@ -2507,15 +2696,21 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "url" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" +checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.1" @@ -2647,6 +2842,16 @@ dependencies = [ "untrusted", ] +[[package]] +name = "webpki" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "webrtc" version = "0.7.2" @@ -2666,7 +2871,7 @@ dependencies = [ "ring", "rtcp", "rtp", - "rustls", + "rustls 0.19.1", "sdp", "serde", "serde_json", @@ -2728,7 +2933,7 @@ dependencies = [ "rand_core 0.6.4", "rcgen", "ring", - "rustls", + "rustls 0.19.1", "sec1", "serde", "sha1 0.10.5", @@ -2737,7 +2942,7 @@ dependencies = [ "subtle", "thiserror", "tokio", - "webpki", + "webpki 0.21.4", "webrtc-util", "x25519-dalek", "x509-parser 0.13.2", @@ -2886,6 +3091,21 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-sys" version = "0.45.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index f3c276b9a..5ff638a39 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,2 +1,2 @@ [workspace] -members = ["relay"] +members = ["relay", "phoenix-channel"] diff --git a/rust/phoenix-channel/Cargo.toml b/rust/phoenix-channel/Cargo.toml new file mode 100644 index 000000000..62576b1ab --- /dev/null +++ b/rust/phoenix-channel/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "phoenix-channel" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio-tungstenite = { version = "0.19.0", features = ["rustls-tls-native-roots"] } +futures = "0.3.28" +base64 = "0.21.2" +serde = { version = "1.0.163", features = ["derive"] } +tracing = "0.1.37" +rand_core = "0.6.4" +url = "2.4.0" +serde_json = "1.0.96" +thiserror = "1.0.40" +tokio = { version = "1.28.2", features = ["net", "time"] } diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs new file mode 100644 index 000000000..e87010289 --- /dev/null +++ b/rust/phoenix-channel/src/lib.rs @@ -0,0 +1,413 @@ +use std::collections::HashSet; +use std::{fmt, marker::PhantomData, time::Duration}; + +use base64::Engine; +use futures::{FutureExt, SinkExt, StreamExt}; +use rand_core::{OsRng, RngCore}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::TcpStream; +use tokio::time::Instant; +use tokio_tungstenite::{ + connect_async, + tungstenite::{handshake::client::Request, Message}, + MaybeTlsStream, WebSocketStream, +}; +use url::Url; + +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); + +pub struct PhoenixChannel { + stream: WebSocketStream>, + pending_messages: Vec, + next_request_id: u64, + + next_heartbeat: Pin>, + + _phantom: PhantomData<(TInboundMsg, TOutboundRes)>, + + pending_join_requests: HashSet, +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("provided URI is missing a host")] + MissingHost, + #[error(transparent)] + WebSocket(#[from] tokio_tungstenite::tungstenite::Error), + #[error("failed to serialize message")] + Serde(#[from] serde_json::Error), + #[error("server sent a reply without a reference")] + MissingReplyId, +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct OutboundRequestId(u64); + +impl fmt::Display for OutboundRequestId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "OutReq-{}", self.0) + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct InboundRequestId(u64); + +impl fmt::Display for InboundRequestId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "InReq-{}", self.0) + } +} + +impl PhoenixChannel +where + TInboundMsg: DeserializeOwned, + TOutboundRes: DeserializeOwned, +{ + /// Creates a new [PhoenixChannel] to the given endpoint. + /// + /// The provided URL must contain a host. + /// Additionally, you must already provide any query parameters required for authentication. + pub async fn connect(uri: Url, user_agent: String) -> Result { + tracing::trace!("Trying to connect to the portal..."); + + let (stream, _) = connect_async(make_request(&uri, user_agent)?).await?; + + tracing::trace!("Successfully connected to portal"); + + Ok(Self { + stream, + pending_messages: vec![], + _phantom: PhantomData, + next_request_id: 0, + next_heartbeat: Box::pin(tokio::time::sleep(HEARTBEAT_INTERVAL)), + pending_join_requests: Default::default(), + }) + } + + /// Join the provided room. + /// + /// If successful, a [`Event::JoinedRoom`] event will be emitted. + pub fn join(&mut self, topic: impl Into, payload: impl Serialize) { + let request_id = self.send_message(topic, EgressControlMessage::PhxJoin(payload)); + + self.pending_join_requests.insert(request_id); + } + + /// Send a message to a topic. + pub fn send(&mut self, topic: impl Into, message: impl Serialize) -> OutboundRequestId { + self.send_message(topic, message) + } + + pub fn poll( + &mut self, + cx: &mut Context, + ) -> Poll, Error>> { + loop { + // Priority 1: Keep local buffers small and send pending messages. + if self.stream.poll_ready_unpin(cx).is_ready() { + if let Some(message) = self.pending_messages.pop() { + self.stream.start_send_unpin(message)?; + continue; + } + } + + // Priority 2: Handle incoming messages. + if let Poll::Ready(Some(message)) = self.stream.poll_next_unpin(cx)? { + let Ok(text) = message.into_text() else { + tracing::warn!("Received non-text message from portal"); + continue; + }; + + tracing::trace!("Received message from portal: {text}"); + + let message = match serde_json::from_str::>( + &text, + ) { + Ok(m) => m, + Err(e) => { + tracing::warn!("Failed to deserialize message {text}: {e}"); + continue; + } + }; + + match message.payload { + Payload::Message(msg) => match message.reference { + None => { + return Poll::Ready(Ok(Event::InboundMessage { + topic: message.topic, + msg, + })) + } + Some(reference) => { + return Poll::Ready(Ok(Event::InboundReq { + req_id: InboundRequestId(reference), + req: msg, + })) + } + }, + Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(ErrorInfo::Reason( + reason, + )))) => { + return Poll::Ready(Ok(Event::ErrorResponse { + topic: message.topic, + req_id: OutboundRequestId( + message.reference.ok_or(Error::MissingReplyId)?, + ), + reason, + })); + } + Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::Message( + reply, + )))) => { + let req_id = + OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?); + + if self.pending_join_requests.remove(&req_id) { + // For `phx_join` requests, `reply` is empty so we can safely ignore it. + return Poll::Ready(Ok(Event::JoinedRoom { + topic: message.topic, + })); + } + + return Poll::Ready(Ok(Event::SuccessResponse { + topic: message.topic, + req_id, + res: reply, + })); + } + Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(ErrorInfo::Offline))) => { + tracing::warn!( + "Received offline error for request {:?}", + message.reference + ); + continue; + } + Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::NoMessage( + Empty {}, + )))) => { + tracing::trace!("Received empty reply for request {:?}", message.reference); + continue; + } + Payload::Reply(ReplyMessage::PhxError(Empty {})) => { + return Poll::Ready(Ok(Event::ErrorResponse { + topic: message.topic, + req_id: OutboundRequestId( + message.reference.ok_or(Error::MissingReplyId)?, + ), + reason: "unknown error (bad event?)".to_owned(), + })) + } + } + } + + // Priority 3: Handle heartbeats. + if self.next_heartbeat.poll_unpin(cx).is_ready() { + self.send_message("phoenix", EgressControlMessage::<()>::Heartbeat(Empty {})); + self.next_heartbeat + .as_mut() + .reset(Instant::now() + HEARTBEAT_INTERVAL); + + return Poll::Ready(Ok(Event::HeartbeatSent)); + } + + return Poll::Pending; + } + } + + fn send_message( + &mut self, + topic: impl Into, + payload: impl Serialize, + ) -> OutboundRequestId { + let request_id = self.fetch_add_request_id(); + + self.pending_messages.push(Message::Text( + // We don't care about the reply type when serializing + serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, request_id)) + .expect("we should always be able to serialize a join topic message"), + )); + + OutboundRequestId(request_id) + } + + fn fetch_add_request_id(&mut self) -> u64 { + let next_id = self.next_request_id; + self.next_request_id += 1; + + next_id + } +} + +#[derive(Debug)] +pub enum Event { + SuccessResponse { + topic: String, + req_id: OutboundRequestId, + /// The response received for an outbound request. + res: TOutboundRes, + }, + JoinedRoom { + topic: String, + }, + HeartbeatSent, + ErrorResponse { + topic: String, + req_id: OutboundRequestId, + reason: String, + }, + /// The server sent us a message, most likely this is a broadcast to all connected clients. + InboundMessage { + topic: String, + msg: TInboundMsg, + }, + /// The server sent us a request and is expecting a response. + InboundReq { + req_id: InboundRequestId, + req: TInboundMsg, + }, +} + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +#[serde(untagged)] +enum Payload { + // We might want other type for the reply message + // but that makes everything even more convoluted! + // and we need to think how to make this whole mess less convoluted. + Reply(ReplyMessage), + Message(T), +} + +#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] +pub struct PhoenixMessage { + topic: String, + #[serde(flatten)] + payload: Payload, + #[serde(rename = "ref")] + reference: Option, +} + +impl PhoenixMessage { + pub fn new(topic: impl Into, payload: T, reference: u64) -> Self { + Self { + topic: topic.into(), + payload: Payload::Message(payload), + reference: Some(reference), + } + } +} + +// This is basically the same as tungstenite does but we add some new headers (namely user-agent) +fn make_request(uri: &Url, user_agent: String) -> Result { + let host = uri.host().ok_or(Error::MissingHost)?; + let host = if let Some(port) = uri.port() { + format!("{host}:{port}") + } else { + host.to_string() + }; + + let mut r = [0u8; 16]; + OsRng.fill_bytes(&mut r); + let key = base64::engine::general_purpose::STANDARD.encode(r); + + let req = Request::builder() + .method("GET") + .header("Host", host) + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", key) + .header("User-Agent", user_agent) + .uri(uri.as_str()) + .body(()) + .expect("building static request always works"); + + Ok(req) +} + +// Awful hack to get serde_json to generate an empty "{}" instead of using "null" +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] +#[serde(deny_unknown_fields)] +struct Empty {} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(rename_all = "snake_case", tag = "event", content = "payload")] +enum EgressControlMessage { + PhxJoin(T), + Heartbeat(Empty), +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case", tag = "event", content = "payload")] +enum ReplyMessage { + PhxReply(PhxReply), + PhxError(Empty), +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(untagged)] +enum OkReply { + Message(T), + NoMessage(Empty), +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +enum ErrorInfo { + Reason(String), + Offline, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case", tag = "status", content = "response")] +enum PhxReply { + Ok(OkReply), + Error(ErrorInfo), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Deserialize, PartialEq, Debug)] + #[serde(rename_all = "snake_case", tag = "event", content = "payload")] // This line makes it all work. + enum Msg { + Shout { hello: String }, + Init {}, + } + + #[test] + fn can_deserialize_inbound_message() { + let msg = r#"{ + "topic": "room:lobby", + "ref": null, + "payload": { + "hello": "world" + }, + "join_ref": null, + "event": "shout" +}"#; + + let msg = serde_json::from_str::>(msg).unwrap(); + + assert_eq!(msg.topic, "room:lobby"); + assert_eq!(msg.reference, None); + assert_eq!( + msg.payload, + Payload::Message(Msg::Shout { + hello: "world".to_owned() + }) + ); + } + #[test] + fn can_deserialize_init_message() { + let msg = r#"{"event":"init","payload":{},"ref":null,"topic":"relay"}"#; + + let msg = serde_json::from_str::>(msg).unwrap(); + + assert_eq!(msg.topic, "relay"); + assert_eq!(msg.reference, None); + assert_eq!(msg.payload, Payload::Message(Msg::Init {})); + } +} diff --git a/rust/relay/Cargo.toml b/rust/relay/Cargo.toml index 82be53f6c..9b52ddb1b 100644 --- a/rust/relay/Cargo.toml +++ b/rust/relay/Cargo.toml @@ -24,6 +24,9 @@ proptest = { version = "1.2.0", optional = true } test-strategy = "0.3.0" derive_more = { version = "0.99.17", features = ["from"] } uuid = { version = "1.3.3", features = ["v4"] } +phoenix-channel = { path = "../phoenix-channel" } +url = "2.4.0" +serde = { version = "1.0.163", features = ["derive"] } [dev-dependencies] webrtc = "0.7.2" diff --git a/rust/relay/src/auth.rs b/rust/relay/src/auth.rs index 3733988fc..fac7e40fb 100644 --- a/rust/relay/src/auth.rs +++ b/rust/relay/src/auth.rs @@ -154,13 +154,11 @@ mod tests { #[test] fn generate_password_test_vector_elixir() { let expiry = systemtime_from_unix(1685984278); - let password = generate_password( "1cab293a-4032-46f4-862a-40e5d174b0d2", expiry, "uvdgKvS9GXYZ_vmv", ); - assert_eq!(password, "6xUIoZ+QvxKhRasLifwfRkMXl+ETLJUsFkHlXjlHAkg") } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 4dbc3e2bd..5f078940a 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -1,7 +1,8 @@ -use anyhow::{Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use clap::Parser; use futures::channel::mpsc; -use futures::{FutureExt, SinkExt, StreamExt}; +use futures::{future, FutureExt, SinkExt, StreamExt}; +use phoenix_channel::{Error, Event, PhoenixChannel}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use relay::{AllocationId, Command, Server, Sleep, UdpSocket}; @@ -14,6 +15,7 @@ use std::time::SystemTime; use tokio::task; use tracing::level_filters::LevelFilter; use tracing_subscriber::EnvFilter; +use url::Url; #[derive(Parser, Debug)] struct Args { @@ -27,9 +29,17 @@ struct Args { /// Must not be a wildcard-address. #[arg(long, env)] listen_ip4_addr: Ipv4Addr, + /// The websocket URL of the portal server to connect to. + /// + /// If omitted, the relay server will start immediately, otherwise we first log on and wait for the `init` message. + #[arg(long, env)] + portal_ws_url: Option, + /// Whether to allow connecting to the portal over an insecure connection. + #[arg(long)] + allow_insecure_ws: bool, /// A seed to use for all randomness operations. /// - /// Useful for testing and only available in debug builds. + /// Only available in debug builds. #[arg(long, env)] rng_seed: Option, } @@ -50,17 +60,78 @@ async fn main() -> Result<()> { tracing::info!("Relay auth secret: {}", hex::encode(server.auth_secret())); - let mut eventloop = Eventloop::new(server, args.listen_ip4_addr).await?; + let channel = if let Some(mut portal_url) = args.portal_ws_url { + if portal_url.scheme() == "ws" && !args.allow_insecure_ws { + bail!("Refusing to connect to portal over insecure connection, pass --allow-insecure-ws to override") + } + + portal_url + .query_pairs_mut() + .append_pair("ipv4", &args.listen_ip4_addr.to_string()); + + let mut channel = PhoenixChannel::::connect( + portal_url.clone(), + format!("relay/{}", env!("CARGO_PKG_VERSION")), + ) + .await + .context("Failed to connect to the portal")?; + + tracing::info!("Connected to portal, waiting for init message",); + + loop { + channel.join( + "relay", + JoinMessage { + stamp_secret: hex::encode(server.auth_secret()), + }, + ); + + let event = future::poll_fn(|cx| channel.poll(cx)) + .await + .context("portal connection failed")?; + + match event { + Event::JoinedRoom { topic } if topic == "relay" => { + tracing::info!("Joined relay room on portal") + } + Event::InboundMessage { + topic, + msg: InboundPortalMessage::Init {}, + } => { + tracing::info!("Received init message from portal on topic {topic}, starting relay activities"); + break Some(channel); + } + other => { + tracing::debug!("Unhandled message from portal: {other:?}"); + } + } + } + } else { + None + }; + + let mut eventloop = Eventloop::new(server, channel, args.listen_ip4_addr).await?; tracing::info!("Listening for incoming traffic on UDP port 3478"); - futures::future::poll_fn(|cx| eventloop.poll(cx)) + future::poll_fn(|cx| eventloop.poll(cx)) .await .context("event loop failed")?; Ok(()) } +#[derive(serde::Serialize, PartialEq, Debug)] +struct JoinMessage { + stamp_secret: String, +} + +#[derive(serde::Deserialize, PartialEq, Debug)] +#[serde(rename_all = "snake_case", tag = "event", content = "payload")] +enum InboundPortalMessage { + Init {}, +} + #[cfg(debug_assertions)] fn make_rng(seed: Option) -> StdRng { let Some(seed) = seed else { @@ -85,6 +156,7 @@ struct Eventloop { ip4_socket: UdpSocket, listen_ip4_address: Ipv4Addr, server: Server, + channel: Option>, allocations: HashMap, relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, relay_data_receiver: mpsc::Receiver<(Vec, SocketAddr, AllocationId)>, @@ -106,13 +178,18 @@ impl Eventloop where R: Rng, { - async fn new(server: Server, listen_ip4_address: Ipv4Addr) -> Result { + async fn new( + server: Server, + channel: Option>, + listen_ip4_address: Ipv4Addr, + ) -> Result { let (sender, receiver) = mpsc::channel(1); Ok(Self { ip4_socket: UdpSocket::bind((listen_ip4_address, 3478)).await?, listen_ip4_address, server, + channel, allocations: Default::default(), relay_data_sender: sender, relay_data_receiver: receiver, @@ -240,6 +317,50 @@ where continue; // Handle potentially new commands. } + // Priority 7: Handle portal messages + match self.channel.as_mut().map(|c| c.poll(cx)) { + Some(Poll::Ready(Ok(Event::InboundMessage { + msg: InboundPortalMessage::Init {}, + .. + }))) => { + tracing::warn!("Received init message during operation"); + continue; + } + Some(Poll::Ready(Err(Error::Serde(e)))) => { + tracing::warn!("Failed to deserialize portal message: {e}"); + continue; // This is not a hard-error, we can continue. + } + Some(Poll::Ready(Err(e))) => { + return Poll::Ready(Err(anyhow!("Portal connection failed: {e}"))); + } + Some(Poll::Ready(Ok(Event::SuccessResponse { res: (), .. }))) => { + continue; + } + Some(Poll::Ready(Ok(Event::JoinedRoom { topic }))) => { + tracing::info!("Successfully joined room '{topic}'"); + continue; + } + Some(Poll::Ready(Ok(Event::ErrorResponse { + topic, + req_id, + reason, + }))) => { + tracing::warn!("Request with ID {req_id} on topic {topic} failed: {reason}"); + continue; + } + Some(Poll::Ready(Ok(Event::InboundReq { + req: InboundPortalMessage::Init {}, + .. + }))) => { + return Poll::Ready(Err(anyhow!("Init message is not a request"))); + } + Some(Poll::Ready(Ok(Event::HeartbeatSent))) => { + tracing::debug!("Heartbeat sent to relay"); + continue; + } + Some(Poll::Pending) | None => {} + } + return Poll::Pending; } }