mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
feat(relay): connect to portal on startup (#1643)
With this PR, the relay can be configured with a WebSocket URL on startup. If given, it will attempt to connect to it and join the `relay` room with its `stamp_secret`. Once the `init` message is received, regular relay operation will begin.
This commit is contained in:
256
rust/Cargo.lock
generated
256
rust/Cargo.lock
generated
@@ -287,9 +287,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.0"
|
||||
version = "0.21.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
|
||||
checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
|
||||
|
||||
[[package]]
|
||||
name = "base64ct"
|
||||
@@ -514,6 +514,22 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.7"
|
||||
@@ -878,9 +894,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "form_urlencoded"
|
||||
version = "1.1.0"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
|
||||
checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
@@ -1107,6 +1123,23 @@ dependencies = [
|
||||
"sha1 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httparse"
|
||||
version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.1.0"
|
||||
@@ -1121,9 +1154,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.3.0"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
|
||||
checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c"
|
||||
dependencies = [
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
@@ -1419,6 +1452,12 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
|
||||
|
||||
[[package]]
|
||||
name = "overload"
|
||||
version = "0.1.1"
|
||||
@@ -1500,9 +1539,25 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.2.0"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
|
||||
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
|
||||
|
||||
[[package]]
|
||||
name = "phoenix-channel"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"futures",
|
||||
"rand_core 0.6.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
@@ -1741,7 +1796,7 @@ name = "relay"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.21.0",
|
||||
"base64 0.21.2",
|
||||
"bytecodec",
|
||||
"bytes",
|
||||
"clap",
|
||||
@@ -1752,15 +1807,18 @@ dependencies = [
|
||||
"hex",
|
||||
"hex-literal",
|
||||
"once_cell",
|
||||
"phoenix-channel",
|
||||
"proptest",
|
||||
"rand",
|
||||
"redis",
|
||||
"serde",
|
||||
"sha2",
|
||||
"stun_codec",
|
||||
"test-strategy",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"uuid",
|
||||
"webrtc",
|
||||
]
|
||||
@@ -1857,8 +1915,51 @@ dependencies = [
|
||||
"base64 0.13.1",
|
||||
"log",
|
||||
"ring",
|
||||
"sct",
|
||||
"webpki",
|
||||
"sct 0.6.1",
|
||||
"webpki 0.21.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.21.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e32ca28af694bc1bbf399c33a516dbdf1c90090b8ab23c2bc24f834aa2247f5f"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-webpki",
|
||||
"sct 0.7.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pemfile",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.100.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1879,6 +1980,15 @@ version = "1.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3"
|
||||
dependencies = [
|
||||
"windows-sys 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
@@ -1895,6 +2005,16 @@ dependencies = [
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sdp"
|
||||
version = "0.5.3"
|
||||
@@ -1921,6 +2041,29 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"core-foundation",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.17"
|
||||
@@ -2287,9 +2430,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.28.0"
|
||||
version = "1.28.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3c786bf8134e5a3a166db9b29ab8f48134739014a3eca7bc6bfa95d673b136f"
|
||||
checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"bytes",
|
||||
@@ -2315,6 +2458,31 @@ dependencies = [
|
||||
"syn 2.0.15",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.24.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
|
||||
dependencies = [
|
||||
"rustls 0.21.2",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"rustls 0.21.2",
|
||||
"rustls-native-certs",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.8"
|
||||
@@ -2421,6 +2589,27 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"rustls 0.21.2",
|
||||
"sha1 0.10.5",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
"webpki 0.22.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "turn"
|
||||
version = "0.6.1"
|
||||
@@ -2507,15 +2696,21 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.3.1"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
|
||||
checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8parse"
|
||||
version = "0.2.1"
|
||||
@@ -2647,6 +2842,16 @@ dependencies = [
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webrtc"
|
||||
version = "0.7.2"
|
||||
@@ -2666,7 +2871,7 @@ dependencies = [
|
||||
"ring",
|
||||
"rtcp",
|
||||
"rtp",
|
||||
"rustls",
|
||||
"rustls 0.19.1",
|
||||
"sdp",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -2728,7 +2933,7 @@ dependencies = [
|
||||
"rand_core 0.6.4",
|
||||
"rcgen",
|
||||
"ring",
|
||||
"rustls",
|
||||
"rustls 0.19.1",
|
||||
"sec1",
|
||||
"serde",
|
||||
"sha1 0.10.5",
|
||||
@@ -2737,7 +2942,7 @@ dependencies = [
|
||||
"subtle",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"webpki",
|
||||
"webpki 0.21.4",
|
||||
"webrtc-util",
|
||||
"x25519-dalek",
|
||||
"x509-parser 0.13.2",
|
||||
@@ -2886,6 +3091,21 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.42.2",
|
||||
"windows_aarch64_msvc 0.42.2",
|
||||
"windows_i686_gnu 0.42.2",
|
||||
"windows_i686_msvc 0.42.2",
|
||||
"windows_x86_64_gnu 0.42.2",
|
||||
"windows_x86_64_gnullvm 0.42.2",
|
||||
"windows_x86_64_msvc 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.45.0"
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
[workspace]
|
||||
members = ["relay"]
|
||||
members = ["relay", "phoenix-channel"]
|
||||
|
||||
18
rust/phoenix-channel/Cargo.toml
Normal file
18
rust/phoenix-channel/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "phoenix-channel"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio-tungstenite = { version = "0.19.0", features = ["rustls-tls-native-roots"] }
|
||||
futures = "0.3.28"
|
||||
base64 = "0.21.2"
|
||||
serde = { version = "1.0.163", features = ["derive"] }
|
||||
tracing = "0.1.37"
|
||||
rand_core = "0.6.4"
|
||||
url = "2.4.0"
|
||||
serde_json = "1.0.96"
|
||||
thiserror = "1.0.40"
|
||||
tokio = { version = "1.28.2", features = ["net", "time"] }
|
||||
413
rust/phoenix-channel/src/lib.rs
Normal file
413
rust/phoenix-channel/src/lib.rs
Normal file
@@ -0,0 +1,413 @@
|
||||
use std::collections::HashSet;
|
||||
use std::{fmt, marker::PhantomData, time::Duration};
|
||||
|
||||
use base64::Engine;
|
||||
use futures::{FutureExt, SinkExt, StreamExt};
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Instant;
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::{handshake::client::Request, Message},
|
||||
MaybeTlsStream, WebSocketStream,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
|
||||
stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
pending_messages: Vec<Message>,
|
||||
next_request_id: u64,
|
||||
|
||||
next_heartbeat: Pin<Box<tokio::time::Sleep>>,
|
||||
|
||||
_phantom: PhantomData<(TInboundMsg, TOutboundRes)>,
|
||||
|
||||
pending_join_requests: HashSet<OutboundRequestId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("provided URI is missing a host")]
|
||||
MissingHost,
|
||||
#[error(transparent)]
|
||||
WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
|
||||
#[error("failed to serialize message")]
|
||||
Serde(#[from] serde_json::Error),
|
||||
#[error("server sent a reply without a reference")]
|
||||
MissingReplyId,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash)]
|
||||
pub struct OutboundRequestId(u64);
|
||||
|
||||
impl fmt::Display for OutboundRequestId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "OutReq-{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash)]
|
||||
pub struct InboundRequestId(u64);
|
||||
|
||||
impl fmt::Display for InboundRequestId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "InReq-{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TInboundMsg, TOutboundRes> PhoenixChannel<TInboundMsg, TOutboundRes>
|
||||
where
|
||||
TInboundMsg: DeserializeOwned,
|
||||
TOutboundRes: DeserializeOwned,
|
||||
{
|
||||
/// Creates a new [PhoenixChannel] to the given endpoint.
|
||||
///
|
||||
/// The provided URL must contain a host.
|
||||
/// Additionally, you must already provide any query parameters required for authentication.
|
||||
pub async fn connect(uri: Url, user_agent: String) -> Result<Self, Error> {
|
||||
tracing::trace!("Trying to connect to the portal...");
|
||||
|
||||
let (stream, _) = connect_async(make_request(&uri, user_agent)?).await?;
|
||||
|
||||
tracing::trace!("Successfully connected to portal");
|
||||
|
||||
Ok(Self {
|
||||
stream,
|
||||
pending_messages: vec![],
|
||||
_phantom: PhantomData,
|
||||
next_request_id: 0,
|
||||
next_heartbeat: Box::pin(tokio::time::sleep(HEARTBEAT_INTERVAL)),
|
||||
pending_join_requests: Default::default(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Join the provided room.
|
||||
///
|
||||
/// If successful, a [`Event::JoinedRoom`] event will be emitted.
|
||||
pub fn join(&mut self, topic: impl Into<String>, payload: impl Serialize) {
|
||||
let request_id = self.send_message(topic, EgressControlMessage::PhxJoin(payload));
|
||||
|
||||
self.pending_join_requests.insert(request_id);
|
||||
}
|
||||
|
||||
/// Send a message to a topic.
|
||||
pub fn send(&mut self, topic: impl Into<String>, message: impl Serialize) -> OutboundRequestId {
|
||||
self.send_message(topic, message)
|
||||
}
|
||||
|
||||
pub fn poll(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
) -> Poll<Result<Event<TInboundMsg, TOutboundRes>, Error>> {
|
||||
loop {
|
||||
// Priority 1: Keep local buffers small and send pending messages.
|
||||
if self.stream.poll_ready_unpin(cx).is_ready() {
|
||||
if let Some(message) = self.pending_messages.pop() {
|
||||
self.stream.start_send_unpin(message)?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Handle incoming messages.
|
||||
if let Poll::Ready(Some(message)) = self.stream.poll_next_unpin(cx)? {
|
||||
let Ok(text) = message.into_text() else {
|
||||
tracing::warn!("Received non-text message from portal");
|
||||
continue;
|
||||
};
|
||||
|
||||
tracing::trace!("Received message from portal: {text}");
|
||||
|
||||
let message = match serde_json::from_str::<PhoenixMessage<TInboundMsg, TOutboundRes>>(
|
||||
&text,
|
||||
) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to deserialize message {text}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match message.payload {
|
||||
Payload::Message(msg) => match message.reference {
|
||||
None => {
|
||||
return Poll::Ready(Ok(Event::InboundMessage {
|
||||
topic: message.topic,
|
||||
msg,
|
||||
}))
|
||||
}
|
||||
Some(reference) => {
|
||||
return Poll::Ready(Ok(Event::InboundReq {
|
||||
req_id: InboundRequestId(reference),
|
||||
req: msg,
|
||||
}))
|
||||
}
|
||||
},
|
||||
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(ErrorInfo::Reason(
|
||||
reason,
|
||||
)))) => {
|
||||
return Poll::Ready(Ok(Event::ErrorResponse {
|
||||
topic: message.topic,
|
||||
req_id: OutboundRequestId(
|
||||
message.reference.ok_or(Error::MissingReplyId)?,
|
||||
),
|
||||
reason,
|
||||
}));
|
||||
}
|
||||
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::Message(
|
||||
reply,
|
||||
)))) => {
|
||||
let req_id =
|
||||
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
|
||||
|
||||
if self.pending_join_requests.remove(&req_id) {
|
||||
// For `phx_join` requests, `reply` is empty so we can safely ignore it.
|
||||
return Poll::Ready(Ok(Event::JoinedRoom {
|
||||
topic: message.topic,
|
||||
}));
|
||||
}
|
||||
|
||||
return Poll::Ready(Ok(Event::SuccessResponse {
|
||||
topic: message.topic,
|
||||
req_id,
|
||||
res: reply,
|
||||
}));
|
||||
}
|
||||
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Error(ErrorInfo::Offline))) => {
|
||||
tracing::warn!(
|
||||
"Received offline error for request {:?}",
|
||||
message.reference
|
||||
);
|
||||
continue;
|
||||
}
|
||||
Payload::Reply(ReplyMessage::PhxReply(PhxReply::Ok(OkReply::NoMessage(
|
||||
Empty {},
|
||||
)))) => {
|
||||
tracing::trace!("Received empty reply for request {:?}", message.reference);
|
||||
continue;
|
||||
}
|
||||
Payload::Reply(ReplyMessage::PhxError(Empty {})) => {
|
||||
return Poll::Ready(Ok(Event::ErrorResponse {
|
||||
topic: message.topic,
|
||||
req_id: OutboundRequestId(
|
||||
message.reference.ok_or(Error::MissingReplyId)?,
|
||||
),
|
||||
reason: "unknown error (bad event?)".to_owned(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Handle heartbeats.
|
||||
if self.next_heartbeat.poll_unpin(cx).is_ready() {
|
||||
self.send_message("phoenix", EgressControlMessage::<()>::Heartbeat(Empty {}));
|
||||
self.next_heartbeat
|
||||
.as_mut()
|
||||
.reset(Instant::now() + HEARTBEAT_INTERVAL);
|
||||
|
||||
return Poll::Ready(Ok(Event::HeartbeatSent));
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
fn send_message(
|
||||
&mut self,
|
||||
topic: impl Into<String>,
|
||||
payload: impl Serialize,
|
||||
) -> OutboundRequestId {
|
||||
let request_id = self.fetch_add_request_id();
|
||||
|
||||
self.pending_messages.push(Message::Text(
|
||||
// We don't care about the reply type when serializing
|
||||
serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, request_id))
|
||||
.expect("we should always be able to serialize a join topic message"),
|
||||
));
|
||||
|
||||
OutboundRequestId(request_id)
|
||||
}
|
||||
|
||||
fn fetch_add_request_id(&mut self) -> u64 {
|
||||
let next_id = self.next_request_id;
|
||||
self.next_request_id += 1;
|
||||
|
||||
next_id
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Event<TInboundMsg, TOutboundRes> {
|
||||
SuccessResponse {
|
||||
topic: String,
|
||||
req_id: OutboundRequestId,
|
||||
/// The response received for an outbound request.
|
||||
res: TOutboundRes,
|
||||
},
|
||||
JoinedRoom {
|
||||
topic: String,
|
||||
},
|
||||
HeartbeatSent,
|
||||
ErrorResponse {
|
||||
topic: String,
|
||||
req_id: OutboundRequestId,
|
||||
reason: String,
|
||||
},
|
||||
/// The server sent us a message, most likely this is a broadcast to all connected clients.
|
||||
InboundMessage {
|
||||
topic: String,
|
||||
msg: TInboundMsg,
|
||||
},
|
||||
/// The server sent us a request and is expecting a response.
|
||||
InboundReq {
|
||||
req_id: InboundRequestId,
|
||||
req: TInboundMsg,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
|
||||
#[serde(untagged)]
|
||||
enum Payload<T, R> {
|
||||
// We might want other type for the reply message
|
||||
// but that makes everything even more convoluted!
|
||||
// and we need to think how to make this whole mess less convoluted.
|
||||
Reply(ReplyMessage<R>),
|
||||
Message(T),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
|
||||
pub struct PhoenixMessage<T, R> {
|
||||
topic: String,
|
||||
#[serde(flatten)]
|
||||
payload: Payload<T, R>,
|
||||
#[serde(rename = "ref")]
|
||||
reference: Option<u64>,
|
||||
}
|
||||
|
||||
impl<T, R> PhoenixMessage<T, R> {
|
||||
pub fn new(topic: impl Into<String>, payload: T, reference: u64) -> Self {
|
||||
Self {
|
||||
topic: topic.into(),
|
||||
payload: Payload::Message(payload),
|
||||
reference: Some(reference),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
|
||||
fn make_request(uri: &Url, user_agent: String) -> Result<Request, Error> {
|
||||
let host = uri.host().ok_or(Error::MissingHost)?;
|
||||
let host = if let Some(port) = uri.port() {
|
||||
format!("{host}:{port}")
|
||||
} else {
|
||||
host.to_string()
|
||||
};
|
||||
|
||||
let mut r = [0u8; 16];
|
||||
OsRng.fill_bytes(&mut r);
|
||||
let key = base64::engine::general_purpose::STANDARD.encode(r);
|
||||
|
||||
let req = Request::builder()
|
||||
.method("GET")
|
||||
.header("Host", host)
|
||||
.header("Connection", "Upgrade")
|
||||
.header("Upgrade", "websocket")
|
||||
.header("Sec-WebSocket-Version", "13")
|
||||
.header("Sec-WebSocket-Key", key)
|
||||
.header("User-Agent", user_agent)
|
||||
.uri(uri.as_str())
|
||||
.body(())
|
||||
.expect("building static request always works");
|
||||
|
||||
Ok(req)
|
||||
}
|
||||
|
||||
// Awful hack to get serde_json to generate an empty "{}" instead of using "null"
|
||||
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
struct Empty {}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
||||
enum EgressControlMessage<T> {
|
||||
PhxJoin(T),
|
||||
Heartbeat(Empty),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
||||
enum ReplyMessage<T> {
|
||||
PhxReply(PhxReply<T>),
|
||||
PhxError(Empty),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||
#[serde(untagged)]
|
||||
enum OkReply<T> {
|
||||
Message(T),
|
||||
NoMessage(Empty),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ErrorInfo {
|
||||
Reason(String),
|
||||
Offline,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case", tag = "status", content = "response")]
|
||||
enum PhxReply<T> {
|
||||
Ok(OkReply<T>),
|
||||
Error(ErrorInfo),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[derive(Deserialize, PartialEq, Debug)]
|
||||
#[serde(rename_all = "snake_case", tag = "event", content = "payload")] // This line makes it all work.
|
||||
enum Msg {
|
||||
Shout { hello: String },
|
||||
Init {},
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialize_inbound_message() {
|
||||
let msg = r#"{
|
||||
"topic": "room:lobby",
|
||||
"ref": null,
|
||||
"payload": {
|
||||
"hello": "world"
|
||||
},
|
||||
"join_ref": null,
|
||||
"event": "shout"
|
||||
}"#;
|
||||
|
||||
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
|
||||
|
||||
assert_eq!(msg.topic, "room:lobby");
|
||||
assert_eq!(msg.reference, None);
|
||||
assert_eq!(
|
||||
msg.payload,
|
||||
Payload::Message(Msg::Shout {
|
||||
hello: "world".to_owned()
|
||||
})
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn can_deserialize_init_message() {
|
||||
let msg = r#"{"event":"init","payload":{},"ref":null,"topic":"relay"}"#;
|
||||
|
||||
let msg = serde_json::from_str::<PhoenixMessage<Msg, ()>>(msg).unwrap();
|
||||
|
||||
assert_eq!(msg.topic, "relay");
|
||||
assert_eq!(msg.reference, None);
|
||||
assert_eq!(msg.payload, Payload::Message(Msg::Init {}));
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,9 @@ proptest = { version = "1.2.0", optional = true }
|
||||
test-strategy = "0.3.0"
|
||||
derive_more = { version = "0.99.17", features = ["from"] }
|
||||
uuid = { version = "1.3.3", features = ["v4"] }
|
||||
phoenix-channel = { path = "../phoenix-channel" }
|
||||
url = "2.4.0"
|
||||
serde = { version = "1.0.163", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
webrtc = "0.7.2"
|
||||
|
||||
@@ -154,13 +154,11 @@ mod tests {
|
||||
#[test]
|
||||
fn generate_password_test_vector_elixir() {
|
||||
let expiry = systemtime_from_unix(1685984278);
|
||||
|
||||
let password = generate_password(
|
||||
"1cab293a-4032-46f4-862a-40e5d174b0d2",
|
||||
expiry,
|
||||
"uvdgKvS9GXYZ_vmv",
|
||||
);
|
||||
|
||||
assert_eq!(password, "6xUIoZ+QvxKhRasLifwfRkMXl+ETLJUsFkHlXjlHAkg")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use clap::Parser;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{FutureExt, SinkExt, StreamExt};
|
||||
use futures::{future, FutureExt, SinkExt, StreamExt};
|
||||
use phoenix_channel::{Error, Event, PhoenixChannel};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use relay::{AllocationId, Command, Server, Sleep, UdpSocket};
|
||||
@@ -14,6 +15,7 @@ use std::time::SystemTime;
|
||||
use tokio::task;
|
||||
use tracing::level_filters::LevelFilter;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
struct Args {
|
||||
@@ -27,9 +29,17 @@ struct Args {
|
||||
/// Must not be a wildcard-address.
|
||||
#[arg(long, env)]
|
||||
listen_ip4_addr: Ipv4Addr,
|
||||
/// The websocket URL of the portal server to connect to.
|
||||
///
|
||||
/// If omitted, the relay server will start immediately, otherwise we first log on and wait for the `init` message.
|
||||
#[arg(long, env)]
|
||||
portal_ws_url: Option<Url>,
|
||||
/// Whether to allow connecting to the portal over an insecure connection.
|
||||
#[arg(long)]
|
||||
allow_insecure_ws: bool,
|
||||
/// A seed to use for all randomness operations.
|
||||
///
|
||||
/// Useful for testing and only available in debug builds.
|
||||
/// Only available in debug builds.
|
||||
#[arg(long, env)]
|
||||
rng_seed: Option<u64>,
|
||||
}
|
||||
@@ -50,17 +60,78 @@ async fn main() -> Result<()> {
|
||||
|
||||
tracing::info!("Relay auth secret: {}", hex::encode(server.auth_secret()));
|
||||
|
||||
let mut eventloop = Eventloop::new(server, args.listen_ip4_addr).await?;
|
||||
let channel = if let Some(mut portal_url) = args.portal_ws_url {
|
||||
if portal_url.scheme() == "ws" && !args.allow_insecure_ws {
|
||||
bail!("Refusing to connect to portal over insecure connection, pass --allow-insecure-ws to override")
|
||||
}
|
||||
|
||||
portal_url
|
||||
.query_pairs_mut()
|
||||
.append_pair("ipv4", &args.listen_ip4_addr.to_string());
|
||||
|
||||
let mut channel = PhoenixChannel::<InboundPortalMessage, ()>::connect(
|
||||
portal_url.clone(),
|
||||
format!("relay/{}", env!("CARGO_PKG_VERSION")),
|
||||
)
|
||||
.await
|
||||
.context("Failed to connect to the portal")?;
|
||||
|
||||
tracing::info!("Connected to portal, waiting for init message",);
|
||||
|
||||
loop {
|
||||
channel.join(
|
||||
"relay",
|
||||
JoinMessage {
|
||||
stamp_secret: hex::encode(server.auth_secret()),
|
||||
},
|
||||
);
|
||||
|
||||
let event = future::poll_fn(|cx| channel.poll(cx))
|
||||
.await
|
||||
.context("portal connection failed")?;
|
||||
|
||||
match event {
|
||||
Event::JoinedRoom { topic } if topic == "relay" => {
|
||||
tracing::info!("Joined relay room on portal")
|
||||
}
|
||||
Event::InboundMessage {
|
||||
topic,
|
||||
msg: InboundPortalMessage::Init {},
|
||||
} => {
|
||||
tracing::info!("Received init message from portal on topic {topic}, starting relay activities");
|
||||
break Some(channel);
|
||||
}
|
||||
other => {
|
||||
tracing::debug!("Unhandled message from portal: {other:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut eventloop = Eventloop::new(server, channel, args.listen_ip4_addr).await?;
|
||||
|
||||
tracing::info!("Listening for incoming traffic on UDP port 3478");
|
||||
|
||||
futures::future::poll_fn(|cx| eventloop.poll(cx))
|
||||
future::poll_fn(|cx| eventloop.poll(cx))
|
||||
.await
|
||||
.context("event loop failed")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, PartialEq, Debug)]
|
||||
struct JoinMessage {
|
||||
stamp_secret: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, PartialEq, Debug)]
|
||||
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
|
||||
enum InboundPortalMessage {
|
||||
Init {},
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn make_rng(seed: Option<u64>) -> StdRng {
|
||||
let Some(seed) = seed else {
|
||||
@@ -85,6 +156,7 @@ struct Eventloop<R> {
|
||||
ip4_socket: UdpSocket,
|
||||
listen_ip4_address: Ipv4Addr,
|
||||
server: Server<R>,
|
||||
channel: Option<PhoenixChannel<InboundPortalMessage, ()>>,
|
||||
allocations: HashMap<AllocationId, Allocation>,
|
||||
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
relay_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
@@ -106,13 +178,18 @@ impl<R> Eventloop<R>
|
||||
where
|
||||
R: Rng,
|
||||
{
|
||||
async fn new(server: Server<R>, listen_ip4_address: Ipv4Addr) -> Result<Self> {
|
||||
async fn new(
|
||||
server: Server<R>,
|
||||
channel: Option<PhoenixChannel<InboundPortalMessage, ()>>,
|
||||
listen_ip4_address: Ipv4Addr,
|
||||
) -> Result<Self> {
|
||||
let (sender, receiver) = mpsc::channel(1);
|
||||
|
||||
Ok(Self {
|
||||
ip4_socket: UdpSocket::bind((listen_ip4_address, 3478)).await?,
|
||||
listen_ip4_address,
|
||||
server,
|
||||
channel,
|
||||
allocations: Default::default(),
|
||||
relay_data_sender: sender,
|
||||
relay_data_receiver: receiver,
|
||||
@@ -240,6 +317,50 @@ where
|
||||
continue; // Handle potentially new commands.
|
||||
}
|
||||
|
||||
// Priority 7: Handle portal messages
|
||||
match self.channel.as_mut().map(|c| c.poll(cx)) {
|
||||
Some(Poll::Ready(Ok(Event::InboundMessage {
|
||||
msg: InboundPortalMessage::Init {},
|
||||
..
|
||||
}))) => {
|
||||
tracing::warn!("Received init message during operation");
|
||||
continue;
|
||||
}
|
||||
Some(Poll::Ready(Err(Error::Serde(e)))) => {
|
||||
tracing::warn!("Failed to deserialize portal message: {e}");
|
||||
continue; // This is not a hard-error, we can continue.
|
||||
}
|
||||
Some(Poll::Ready(Err(e))) => {
|
||||
return Poll::Ready(Err(anyhow!("Portal connection failed: {e}")));
|
||||
}
|
||||
Some(Poll::Ready(Ok(Event::SuccessResponse { res: (), .. }))) => {
|
||||
continue;
|
||||
}
|
||||
Some(Poll::Ready(Ok(Event::JoinedRoom { topic }))) => {
|
||||
tracing::info!("Successfully joined room '{topic}'");
|
||||
continue;
|
||||
}
|
||||
Some(Poll::Ready(Ok(Event::ErrorResponse {
|
||||
topic,
|
||||
req_id,
|
||||
reason,
|
||||
}))) => {
|
||||
tracing::warn!("Request with ID {req_id} on topic {topic} failed: {reason}");
|
||||
continue;
|
||||
}
|
||||
Some(Poll::Ready(Ok(Event::InboundReq {
|
||||
req: InboundPortalMessage::Init {},
|
||||
..
|
||||
}))) => {
|
||||
return Poll::Ready(Err(anyhow!("Init message is not a request")));
|
||||
}
|
||||
Some(Poll::Ready(Ok(Event::HeartbeatSent))) => {
|
||||
tracing::debug!("Heartbeat sent to relay");
|
||||
continue;
|
||||
}
|
||||
Some(Poll::Pending) | None => {}
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user