feat(relay): connect to portal on startup (#1643)

With this PR, the relay can be configured with a WebSocket URL on startup. If given, it will attempt to connect to it and join the `relay` room with its `stamp_secret`. Once the `init` message is received, regular relay operation will begin.
This commit is contained in:
Thomas Eizinger
2023-06-21 21:10:39 +02:00
committed by GitHub
parent 0f594f44bc
commit 247633ed33
7 changed files with 800 additions and 27 deletions

256
rust/Cargo.lock generated
View File

@@ -287,9 +287,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.21.0"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
[[package]]
name = "base64ct"
@@ -514,6 +514,22 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
[[package]]
name = "core-foundation"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "cpufeatures"
version = "0.2.7"
@@ -878,9 +894,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "form_urlencoded"
version = "1.1.0"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652"
dependencies = [
"percent-encoding",
]
@@ -1107,6 +1123,23 @@ dependencies = [
"sha1 0.2.0",
]
[[package]]
name = "http"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "httparse"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
[[package]]
name = "humantime"
version = "2.1.0"
@@ -1121,9 +1154,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.3.0"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c"
dependencies = [
"unicode-bidi",
"unicode-normalization",
@@ -1419,6 +1452,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "openssl-probe"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "overload"
version = "0.1.1"
@@ -1500,9 +1539,25 @@ dependencies = [
[[package]]
name = "percent-encoding"
version = "2.2.0"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
[[package]]
name = "phoenix-channel"
version = "0.1.0"
dependencies = [
"base64 0.21.2",
"futures",
"rand_core 0.6.4",
"serde",
"serde_json",
"thiserror",
"tokio",
"tokio-tungstenite",
"tracing",
"url",
]
[[package]]
name = "pin-project-lite"
@@ -1741,7 +1796,7 @@ name = "relay"
version = "0.1.0"
dependencies = [
"anyhow",
"base64 0.21.0",
"base64 0.21.2",
"bytecodec",
"bytes",
"clap",
@@ -1752,15 +1807,18 @@ dependencies = [
"hex",
"hex-literal",
"once_cell",
"phoenix-channel",
"proptest",
"rand",
"redis",
"serde",
"sha2",
"stun_codec",
"test-strategy",
"tokio",
"tracing",
"tracing-subscriber",
"url",
"uuid",
"webrtc",
]
@@ -1857,8 +1915,51 @@ dependencies = [
"base64 0.13.1",
"log",
"ring",
"sct",
"webpki",
"sct 0.6.1",
"webpki 0.21.4",
]
[[package]]
name = "rustls"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e32ca28af694bc1bbf399c33a516dbdf1c90090b8ab23c2bc24f834aa2247f5f"
dependencies = [
"log",
"ring",
"rustls-webpki",
"sct 0.7.0",
]
[[package]]
name = "rustls-native-certs"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
dependencies = [
"base64 0.21.2",
]
[[package]]
name = "rustls-webpki"
version = "0.100.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b"
dependencies = [
"ring",
"untrusted",
]
[[package]]
@@ -1879,6 +1980,15 @@ version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]]
name = "schannel"
version = "0.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3"
dependencies = [
"windows-sys 0.42.0",
]
[[package]]
name = "scopeguard"
version = "1.1.0"
@@ -1895,6 +2005,16 @@ dependencies = [
"untrusted",
]
[[package]]
name = "sct"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "sdp"
version = "0.5.3"
@@ -1921,6 +2041,29 @@ dependencies = [
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "semver"
version = "1.0.17"
@@ -2287,9 +2430,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.28.0"
version = "1.28.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3c786bf8134e5a3a166db9b29ab8f48134739014a3eca7bc6bfa95d673b136f"
checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2"
dependencies = [
"autocfg",
"bytes",
@@ -2315,6 +2458,31 @@ dependencies = [
"syn 2.0.15",
]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
dependencies = [
"rustls 0.21.2",
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c"
dependencies = [
"futures-util",
"log",
"rustls 0.21.2",
"rustls-native-certs",
"tokio",
"tokio-rustls",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.8"
@@ -2421,6 +2589,27 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "tungstenite"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand",
"rustls 0.21.2",
"sha1 0.10.5",
"thiserror",
"url",
"utf-8",
"webpki 0.22.0",
]
[[package]]
name = "turn"
version = "0.6.1"
@@ -2507,15 +2696,21 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "url"
version = "2.3.1"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf8parse"
version = "0.2.1"
@@ -2647,6 +2842,16 @@ dependencies = [
"untrusted",
]
[[package]]
name = "webpki"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "webrtc"
version = "0.7.2"
@@ -2666,7 +2871,7 @@ dependencies = [
"ring",
"rtcp",
"rtp",
"rustls",
"rustls 0.19.1",
"sdp",
"serde",
"serde_json",
@@ -2728,7 +2933,7 @@ dependencies = [
"rand_core 0.6.4",
"rcgen",
"ring",
"rustls",
"rustls 0.19.1",
"sec1",
"serde",
"sha1 0.10.5",
@@ -2737,7 +2942,7 @@ dependencies = [
"subtle",
"thiserror",
"tokio",
"webpki",
"webpki 0.21.4",
"webrtc-util",
"x25519-dalek",
"x509-parser 0.13.2",
@@ -2886,6 +3091,21 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.45.0"

View File

@@ -1,2 +1,2 @@
[workspace]
members = ["relay"]
members = ["relay", "phoenix-channel"]

View File

@@ -0,0 +1,18 @@
[package]
name = "phoenix-channel"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio-tungstenite = { version = "0.19.0", features = ["rustls-tls-native-roots"] }
futures = "0.3.28"
base64 = "0.21.2"
serde = { version = "1.0.163", features = ["derive"] }
tracing = "0.1.37"
rand_core = "0.6.4"
url = "2.4.0"
serde_json = "1.0.96"
thiserror = "1.0.40"
tokio = { version = "1.28.2", features = ["net", "time"] }

View File

@@ -0,0 +1,413 @@
use std::collections::HashSet;
use std::{fmt, marker::PhantomData, time::Duration};
use base64::Engine;
use futures::{FutureExt, SinkExt, StreamExt};
use rand_core::{OsRng, RngCore};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_tungstenite::{
connect_async,
tungstenite::{handshake::client::Request, Message},
MaybeTlsStream, WebSocketStream,
};
use url::Url;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
pending_messages: Vec<Message>,
next_request_id: u64,
next_heartbeat: Pin<Box<tokio::time::Sleep>>,
_phantom: PhantomData<(TInboundMsg, TOutboundRes)>,
pending_join_requests: HashSet<OutboundRequestId>,
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("provided URI is missing a host")]
MissingHost,
#[error(transparent)]
WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
#[error("failed to serialize message")]
Serde(#[from] serde_json::Error),
#[error("server sent a reply without a reference")]
MissingReplyId,
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct OutboundRequestId(u64);
impl fmt::Display for OutboundRequestId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "OutReq-{}", self.0)
}
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct InboundRequestId(u64);
impl fmt::Display for InboundRequestId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "InReq-{}", self.0)
}
}
impl<TInboundMsg, TOutboundRes> PhoenixChannel<TInboundMsg, TOutboundRes>
where
TInboundMsg: DeserializeOwned,
TOutboundRes: DeserializeOwned,
{
/// Creates a new [PhoenixChannel] to the given endpoint.
///
/// The provided URL must contain a host.
/// Additionally, you must already provide any query parameters required for authentication.
pub async fn connect(uri: Url, user_agent: String) -> Result<Self, Error> {
tracing::trace!("Trying to connect to the portal...");
let (stream, _) = connect_async(make_request(&uri, user_agent)?).await?;
tracing::trace!("Successfully connected to portal");
Ok(Self {
stream,
pending_messages: vec![],
_phantom: PhantomData,
next_request_id: 0,
next_heartbeat: Box::pin(tokio::time::sleep(HEARTBEAT_INTERVAL)),
pending_join_requests: Default::default(),
})
}
/// Join the provided room.
///
/// If successful, a [`Event::JoinedRoom`] event will be emitted.
pub fn join(&mut self, topic: impl Into<String>, payload: impl Serialize) {
let request_id = self.send_message(topic, EgressControlMessage::PhxJoin(payload));
self.pending_join_requests.insert(request_id);
}
/// Send a message to a topic.
pub fn send(&mut self, topic: impl Into<String>, message: impl Serialize) -> OutboundRequestId {
self.send_message(topic, message)
}
pub fn poll(
&mut self,
cx: &mut Context,
) -> Poll<Result<Event<TInboundMsg, TOutboundRes>, Error>> {
loop {
// Priority 1: Keep local buffers small and send pending messages.
if self.stream.poll_ready_unpin(cx).is_ready() {
if let Some(message) = self.pending_messages.pop() {
self.stream.start_send_unpin(message)?;
continue;
}
}
// Priority 2: Handle incoming messages.
if let Poll::Ready(Some(message)) = self.stream.poll_next_unpin(cx)? {
let Ok(text) = message.into_text() else {
tracing::warn!("Received non-text message from portal");
continue;
};
tracing::trace!("Received message from portal: {text}");
let message = match serde_json::from_str::<PhoenixMessage<TInboundMsg, TOutboundRes>>(
&text,
) {
Ok(m) => m,
Err(e) => {
tracing::warn!("Failed to deserialize message {text}: {e}");
continue;
}
};
match message.payload {
Payload::Message(msg) => match message.reference {
None => {
return Poll::Ready(Ok(Event::InboundMessage {
topic: message.topic,
msg,
}))
}
Some(reference) => {
return Poll::Ready(Ok(Event::InboundReq {
req_id: InboundRequestId(reference),
req: msg,
}))
}
},
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(ErrorInfo::Reason(
reason,
)))) => {
return Poll::Ready(Ok(Event::ErrorResponse {
topic: message.topic,
req_id: OutboundRequestId(
message.reference.ok_or(Error::MissingReplyId)?,
),
reason,
}));
}
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::Message(
reply,
)))) => {
let req_id =
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
if self.pending_join_requests.remove(&req_id) {
// For `phx_join` requests, `reply` is empty so we can safely ignore it.
return Poll::Ready(Ok(Event::JoinedRoom {
topic: message.topic,
}));
}
return Poll::Ready(Ok(Event::SuccessResponse {
topic: message.topic,
req_id,
res: reply,
}));
}
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(ErrorInfo::Offline))) => {
tracing::warn!(
"Received offline error for request {:?}",
message.reference
);
continue;
}
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::NoMessage(
Empty {},
)))) => {
tracing::trace!("Received empty reply for request {:?}", message.reference);
continue;
}
Payload::Reply(ReplyMessage::PhxError(Empty {})) => {
return Poll::Ready(Ok(Event::ErrorResponse {
topic: message.topic,
req_id: OutboundRequestId(
message.reference.ok_or(Error::MissingReplyId)?,
),
reason: "unknown error (bad event?)".to_owned(),
}))
}
}
}
// Priority 3: Handle heartbeats.
if self.next_heartbeat.poll_unpin(cx).is_ready() {
self.send_message("phoenix", EgressControlMessage::<()>::Heartbeat(Empty {}));
self.next_heartbeat
.as_mut()
.reset(Instant::now() + HEARTBEAT_INTERVAL);
return Poll::Ready(Ok(Event::HeartbeatSent));
}
return Poll::Pending;
}
}
fn send_message(
&mut self,
topic: impl Into<String>,
payload: impl Serialize,
) -> OutboundRequestId {
let request_id = self.fetch_add_request_id();
self.pending_messages.push(Message::Text(
// We don't care about the reply type when serializing
serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, request_id))
.expect("we should always be able to serialize a join topic message"),
));
OutboundRequestId(request_id)
}
fn fetch_add_request_id(&mut self) -> u64 {
let next_id = self.next_request_id;
self.next_request_id += 1;
next_id
}
}
#[derive(Debug)]
pub enum Event<TInboundMsg, TOutboundRes> {
SuccessResponse {
topic: String,
req_id: OutboundRequestId,
/// The response received for an outbound request.
res: TOutboundRes,
},
JoinedRoom {
topic: String,
},
HeartbeatSent,
ErrorResponse {
topic: String,
req_id: OutboundRequestId,
reason: String,
},
/// The server sent us a message, most likely this is a broadcast to all connected clients.
InboundMessage {
topic: String,
msg: TInboundMsg,
},
/// The server sent us a request and is expecting a response.
InboundReq {
req_id: InboundRequestId,
req: TInboundMsg,
},
}
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[serde(untagged)]
enum Payload<T, R> {
// We might want other type for the reply message
// but that makes everything even more convoluted!
// and we need to think how to make this whole mess less convoluted.
Reply(ReplyMessage<R>),
Message(T),
}
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
pub struct PhoenixMessage<T, R> {
topic: String,
#[serde(flatten)]
payload: Payload<T, R>,
#[serde(rename = "ref")]
reference: Option<u64>,
}
impl<T, R> PhoenixMessage<T, R> {
pub fn new(topic: impl Into<String>, payload: T, reference: u64) -> Self {
Self {
topic: topic.into(),
payload: Payload::Message(payload),
reference: Some(reference),
}
}
}
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
fn make_request(uri: &Url, user_agent: String) -> Result<Request, Error> {
let host = uri.host().ok_or(Error::MissingHost)?;
let host = if let Some(port) = uri.port() {
format!("{host}:{port}")
} else {
host.to_string()
};
let mut r = [0u8; 16];
OsRng.fill_bytes(&mut r);
let key = base64::engine::general_purpose::STANDARD.encode(r);
let req = Request::builder()
.method("GET")
.header("Host", host)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", key)
.header("User-Agent", user_agent)
.uri(uri.as_str())
.body(())
.expect("building static request always works");
Ok(req)
}
// Awful hack to get serde_json to generate an empty "{}" instead of using "null"
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
#[serde(deny_unknown_fields)]
struct Empty {}
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum EgressControlMessage<T> {
PhxJoin(T),
Heartbeat(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum ReplyMessage<T> {
PhxReply(PhxReply<T>),
PhxError(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(untagged)]
enum OkReply<T> {
Message(T),
NoMessage(Empty),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum ErrorInfo {
Reason(String),
Offline,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
enum PhxReply<T> {
Ok(OkReply<T>),
Error(ErrorInfo),
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Deserialize, PartialEq, Debug)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")] // This line makes it all work.
enum Msg {
Shout { hello: String },
Init {},
}
#[test]
fn can_deserialize_inbound_message() {
let msg = r#"{
"topic": "room:lobby",
"ref": null,
"payload": {
"hello": "world"
},
"join_ref": null,
"event": "shout"
}"#;
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
assert_eq!(msg.topic, "room:lobby");
assert_eq!(msg.reference, None);
assert_eq!(
msg.payload,
Payload::Message(Msg::Shout {
hello: "world".to_owned()
})
);
}
#[test]
fn can_deserialize_init_message() {
let msg = r#"{"event":"init","payload":{},"ref":null,"topic":"relay"}"#;
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
assert_eq!(msg.topic, "relay");
assert_eq!(msg.reference, None);
assert_eq!(msg.payload, Payload::Message(Msg::Init {}));
}
}

View File

@@ -24,6 +24,9 @@ proptest = { version = "1.2.0", optional = true }
test-strategy = "0.3.0"
derive_more = { version = "0.99.17", features = ["from"] }
uuid = { version = "1.3.3", features = ["v4"] }
phoenix-channel = { path = "../phoenix-channel" }
url = "2.4.0"
serde = { version = "1.0.163", features = ["derive"] }
[dev-dependencies]
webrtc = "0.7.2"

View File

@@ -154,13 +154,11 @@ mod tests {
#[test]
fn generate_password_test_vector_elixir() {
let expiry = systemtime_from_unix(1685984278);
let password = generate_password(
"1cab293a-4032-46f4-862a-40e5d174b0d2",
expiry,
"uvdgKvS9GXYZ_vmv",
);
assert_eq!(password, "6xUIoZ+QvxKhRasLifwfRkMXl+ETLJUsFkHlXjlHAkg")
}

View File

@@ -1,7 +1,8 @@
use anyhow::{Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use clap::Parser;
use futures::channel::mpsc;
use futures::{FutureExt, SinkExt, StreamExt};
use futures::{future, FutureExt, SinkExt, StreamExt};
use phoenix_channel::{Error, Event, PhoenixChannel};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use relay::{AllocationId, Command, Server, Sleep, UdpSocket};
@@ -14,6 +15,7 @@ use std::time::SystemTime;
use tokio::task;
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;
use url::Url;
#[derive(Parser, Debug)]
struct Args {
@@ -27,9 +29,17 @@ struct Args {
/// Must not be a wildcard-address.
#[arg(long, env)]
listen_ip4_addr: Ipv4Addr,
/// The websocket URL of the portal server to connect to.
///
/// If omitted, the relay server will start immediately, otherwise we first log on and wait for the `init` message.
#[arg(long, env)]
portal_ws_url: Option<Url>,
/// Whether to allow connecting to the portal over an insecure connection.
#[arg(long)]
allow_insecure_ws: bool,
/// A seed to use for all randomness operations.
///
/// Useful for testing and only available in debug builds.
/// Only available in debug builds.
#[arg(long, env)]
rng_seed: Option<u64>,
}
@@ -50,17 +60,78 @@ async fn main() -> Result<()> {
tracing::info!("Relay auth secret: {}", hex::encode(server.auth_secret()));
let mut eventloop = Eventloop::new(server, args.listen_ip4_addr).await?;
let channel = if let Some(mut portal_url) = args.portal_ws_url {
if portal_url.scheme() == "ws" && !args.allow_insecure_ws {
bail!("Refusing to connect to portal over insecure connection, pass --allow-insecure-ws to override")
}
portal_url
.query_pairs_mut()
.append_pair("ipv4", &args.listen_ip4_addr.to_string());
let mut channel = PhoenixChannel::<InboundPortalMessage, ()>::connect(
portal_url.clone(),
format!("relay/{}", env!("CARGO_PKG_VERSION")),
)
.await
.context("Failed to connect to the portal")?;
tracing::info!("Connected to portal, waiting for init message",);
loop {
channel.join(
"relay",
JoinMessage {
stamp_secret: hex::encode(server.auth_secret()),
},
);
let event = future::poll_fn(|cx| channel.poll(cx))
.await
.context("portal connection failed")?;
match event {
Event::JoinedRoom { topic } if topic == "relay" => {
tracing::info!("Joined relay room on portal")
}
Event::InboundMessage {
topic,
msg: InboundPortalMessage::Init {},
} => {
tracing::info!("Received init message from portal on topic {topic}, starting relay activities");
break Some(channel);
}
other => {
tracing::debug!("Unhandled message from portal: {other:?}");
}
}
}
} else {
None
};
let mut eventloop = Eventloop::new(server, channel, args.listen_ip4_addr).await?;
tracing::info!("Listening for incoming traffic on UDP port 3478");
futures::future::poll_fn(|cx| eventloop.poll(cx))
future::poll_fn(|cx| eventloop.poll(cx))
.await
.context("event loop failed")?;
Ok(())
}
#[derive(serde::Serialize, PartialEq, Debug)]
struct JoinMessage {
stamp_secret: String,
}
#[derive(serde::Deserialize, PartialEq, Debug)]
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
enum InboundPortalMessage {
Init {},
}
#[cfg(debug_assertions)]
fn make_rng(seed: Option<u64>) -> StdRng {
let Some(seed) = seed else {
@@ -85,6 +156,7 @@ struct Eventloop<R> {
ip4_socket: UdpSocket,
listen_ip4_address: Ipv4Addr,
server: Server<R>,
channel: Option<PhoenixChannel<InboundPortalMessage, ()>>,
allocations: HashMap<AllocationId, Allocation>,
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
relay_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr, AllocationId)>,
@@ -106,13 +178,18 @@ impl<R> Eventloop<R>
where
R: Rng,
{
async fn new(server: Server<R>, listen_ip4_address: Ipv4Addr) -> Result<Self> {
async fn new(
server: Server<R>,
channel: Option<PhoenixChannel<InboundPortalMessage, ()>>,
listen_ip4_address: Ipv4Addr,
) -> Result<Self> {
let (sender, receiver) = mpsc::channel(1);
Ok(Self {
ip4_socket: UdpSocket::bind((listen_ip4_address, 3478)).await?,
listen_ip4_address,
server,
channel,
allocations: Default::default(),
relay_data_sender: sender,
relay_data_receiver: receiver,
@@ -240,6 +317,50 @@ where
continue; // Handle potentially new commands.
}
// Priority 7: Handle portal messages
match self.channel.as_mut().map(|c| c.poll(cx)) {
Some(Poll::Ready(Ok(Event::InboundMessage {
msg: InboundPortalMessage::Init {},
..
}))) => {
tracing::warn!("Received init message during operation");
continue;
}
Some(Poll::Ready(Err(Error::Serde(e)))) => {
tracing::warn!("Failed to deserialize portal message: {e}");
continue; // This is not a hard-error, we can continue.
}
Some(Poll::Ready(Err(e))) => {
return Poll::Ready(Err(anyhow!("Portal connection failed: {e}")));
}
Some(Poll::Ready(Ok(Event::SuccessResponse { res: (), .. }))) => {
continue;
}
Some(Poll::Ready(Ok(Event::JoinedRoom { topic }))) => {
tracing::info!("Successfully joined room '{topic}'");
continue;
}
Some(Poll::Ready(Ok(Event::ErrorResponse {
topic,
req_id,
reason,
}))) => {
tracing::warn!("Request with ID {req_id} on topic {topic} failed: {reason}");
continue;
}
Some(Poll::Ready(Ok(Event::InboundReq {
req: InboundPortalMessage::Init {},
..
}))) => {
return Poll::Ready(Err(anyhow!("Init message is not a request")));
}
Some(Poll::Ready(Ok(Event::HeartbeatSent))) => {
tracing::debug!("Heartbeat sent to relay");
continue;
}
Some(Poll::Pending) | None => {}
}
return Poll::Pending;
}
}