diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 7c3104e0d..0ac9a4971 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1171,7 +1171,6 @@ dependencies = [ "log", "mutants", "os_info", - "parking_lot", "phoenix-channel", "rand 0.8.5", "rand_core 0.6.4", @@ -1186,8 +1185,6 @@ dependencies = [ "tempfile", "thiserror", "tokio", - "tokio-stream", - "tokio-tungstenite", "tracing", "tracing-android", "url", diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 53bfd6070..bf47842fe 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -1,7 +1,7 @@ use crate::{ messages::{ BroadcastGatewayIceCandidates, Connect, ConnectionDetails, EgressMessages, - GatewayIceCandidates, IngressMessages, InitClient, RemoveResource, ReplyMessages, + GatewayIceCandidates, IngressMessages, InitClient, ReplyMessages, }, PHOENIX_TOPIC, }; @@ -189,7 +189,7 @@ where tracing::warn!(%resource_id, "Failed to add resource: {e}"); } } - IngressMessages::ResourceDeleted(RemoveResource(resource)) => { + IngressMessages::ResourceDeleted(resource) => { self.tunnel.remove_resources(&[resource]); } } diff --git a/rust/connlib/clients/shared/src/messages.rs b/rust/connlib/clients/shared/src/messages.rs index 33d7cb608..e6aa70f5d 100644 --- a/rust/connlib/clients/shared/src/messages.rs +++ b/rust/connlib/clients/shared/src/messages.rs @@ -1,27 +1,22 @@ -use std::{collections::HashSet, net::IpAddr}; - -use serde::{Deserialize, Serialize}; - use connlib_shared::messages::{ GatewayId, GatewayResponse, Interface, Key, Relay, RequestConnection, ResourceDescription, ResourceId, ReuseConnection, }; +use serde::{Deserialize, Serialize}; +use std::{collections::HashSet, net::IpAddr}; -#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +#[derive(Debug, PartialEq, Eq, Deserialize, Clone)] pub struct InitClient { pub interface: Interface, - #[serde(skip_serializing_if = "Vec::is_empty", default)] + #[serde(default)] pub resources: Vec, } -#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +#[derive(Debug, PartialEq, Eq, Deserialize, Clone)] pub struct ConfigUpdate { pub interface: Interface, } -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -pub struct RemoveResource(pub ResourceId); - #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] pub struct ConnectionDetails { pub relays: Vec, @@ -30,7 +25,7 @@ pub struct ConnectionDetails { pub gateway_remote_ip: IpAddr, } -#[derive(Debug, Deserialize, Serialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct Connect { pub gateway_payload: GatewayResponse, pub resource_id: ResourceId, @@ -38,25 +33,16 @@ pub struct Connect { pub persistent_keepalive: u64, } -// Just because RTCSessionDescription doesn't implement partialeq -impl PartialEq for Connect { - fn eq(&self, other: &Self) -> bool { - self.resource_id == other.resource_id && self.gateway_public_key == other.gateway_public_key - } -} - -impl Eq for Connect {} - // These messages are the messages that can be received // by a client. -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Clone, PartialEq)] #[serde(rename_all = "snake_case", tag = "event", content = "payload")] pub enum IngressMessages { Init(InitClient), // Resources: arrive in an orderly fashion ResourceCreatedOrUpdated(ResourceDescription), - ResourceDeleted(RemoveResource), + ResourceDeleted(ResourceId), IceCandidates(GatewayIceCandidates), @@ -64,7 +50,7 @@ pub enum IngressMessages { } /// A gateway's ice candidate message. -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct BroadcastGatewayIceCandidates { /// Gateway's id the ice candidates are meant for pub gateway_ids: Vec, @@ -73,7 +59,7 @@ pub struct BroadcastGatewayIceCandidates { } /// A gateway's ice candidate message. -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct GatewayIceCandidates { /// Gateway's id the ice candidates are from pub gateway_id: GatewayId, @@ -82,7 +68,7 @@ pub struct GatewayIceCandidates { } /// The replies that can arrive from the channel by a client -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Clone, PartialEq)] #[serde(untagged)] #[allow(clippy::large_enum_variant)] pub enum ReplyMessages { @@ -90,46 +76,8 @@ pub enum ReplyMessages { Connect(Connect), } -/// The totality of all messages (might have a macro in the future to derive the other types) -#[derive(Debug, Clone, PartialEq, Eq)] -#[allow(clippy::large_enum_variant)] -pub enum Messages { - Init(InitClient), - ConnectionDetails(ConnectionDetails), - Connect(Connect), - - // Resources: arrive in an orderly fashion - ResourceCreatedOrUpdated(ResourceDescription), - ResourceDeleted(RemoveResource), - - IceCandidates(GatewayIceCandidates), - - ConfigChanged(ConfigUpdate), -} - -impl From for Messages { - fn from(value: IngressMessages) -> Self { - match value { - IngressMessages::Init(m) => Self::Init(m), - IngressMessages::ResourceCreatedOrUpdated(m) => Self::ResourceCreatedOrUpdated(m), - IngressMessages::ResourceDeleted(m) => Self::ResourceDeleted(m), - IngressMessages::IceCandidates(m) => Self::IceCandidates(m), - IngressMessages::ConfigChanged(m) => Self::ConfigChanged(m), - } - } -} - -impl From for Messages { - fn from(value: ReplyMessages) -> Self { - match value { - ReplyMessages::ConnectionDetails(m) => Self::ConnectionDetails(m), - ReplyMessages::Connect(m) => Self::Connect(m), - } - } -} - // These messages can be sent from a client to a control pane -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] #[serde(rename_all = "snake_case", tag = "event", content = "payload")] // enum_variant_names: These are the names in the portal! pub enum EgressMessages { @@ -144,19 +92,14 @@ pub enum EgressMessages { #[cfg(test)] mod test { - use std::collections::HashSet; - + use super::*; + use chrono::DateTime; use connlib_shared::messages::{ DnsServer, Interface, IpDnsServer, Relay, ResourceDescription, ResourceDescriptionCidr, ResourceDescriptionDns, Stun, Turn, }; use phoenix_channel::PhoenixMessage; - - use chrono::DateTime; - - use crate::messages::{ConnectionDetails, EgressMessages, ReplyMessages}; - - use super::{ConfigUpdate, IngressMessages, InitClient}; + use std::collections::HashSet; // TODO: request_connection tests diff --git a/rust/connlib/shared/Cargo.toml b/rust/connlib/shared/Cargo.toml index f0ee1d5b3..2697e1076 100644 --- a/rust/connlib/shared/Cargo.toml +++ b/rust/connlib/shared/Cargo.toml @@ -18,16 +18,13 @@ futures = { version = "0.3", default-features = false, features = ["std", "asyn futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] } ip_network = { version = "0.4", default-features = false, features = ["serde"] } os_info = { version = "3", default-features = false } -parking_lot = "0.12" rand = { version = "0.8", default-features = false, features = ["std"] } rand_core = { version = "0.6.4", default-features = false, features = ["std"] } resolv-conf = "0.7.0" serde = { version = "1.0", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } thiserror = { version = "1.0", default-features = false } -tokio = { version = "1.36", default-features = false, features = ["rt", "rt-multi-thread", "fs"]} -tokio-stream = { version = "0.1", features = ["time"] } -tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] } +tokio = { version = "1.36", features = ["fs"] } tracing = { workspace = true } url = { version = "2.4.1", default-features = false } uuid = { version = "1.7", default-features = false, features = ["std", "v4", "serde"] } @@ -45,6 +42,7 @@ anyhow = "1.0" itertools = "0.12" tempfile = "3.10.1" mutants = "0.0.3" # Needed to mark functions as exempt from `cargo-mutants` testing +tokio = { version = "1.36", features = ["macros", "rt"] } [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies] swift-bridge = { workspace = true } diff --git a/rust/connlib/shared/src/control.rs b/rust/connlib/shared/src/control.rs deleted file mode 100644 index e31742e70..000000000 --- a/rust/connlib/shared/src/control.rs +++ /dev/null @@ -1,587 +0,0 @@ -//! Control protocol related module. -//! -//! This modules contains the logic for handling in and out messages through the control plane. -//! Handling of the message itself can be found in the other lib crates. -//! -//! Entrypoint for this module is [PhoenixChannel]. -use std::{marker::PhantomData, time::Duration}; - -use base64::Engine; -use futures::{ - channel::mpsc::{channel, Receiver, Sender}, - TryStreamExt, -}; -use futures_util::{Future, SinkExt, StreamExt, TryFutureExt}; -use rand_core::{OsRng, RngCore}; -use secrecy::Secret; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use tokio_stream::StreamExt as _; -use tokio_tungstenite::{ - connect_async, - tungstenite::{self, handshake::client::Request}, -}; -use tungstenite::Message; - -use crate::{get_user_agent, Error, Result}; -use phoenix_channel::LoginUrl; - -const CHANNEL_SIZE: usize = 1_000; -const HEARTBEAT: Duration = Duration::from_secs(30); -const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(35); - -pub type Reference = String; - -/// Main struct to interact with the control-protocol channel. -/// -/// After creating a new `PhoenixChannel` using [PhoenixChannel::new] you need to -/// use [start][PhoenixChannel::start] for the channel to do anything. -/// -/// If you want to send something through the channel you need to obtain a [PhoenixSender] through -/// [PhoenixChannel::sender], this will already clone the sender so no need to clone it after you obtain it. -/// -/// When [PhoenixChannel::start] is called a new websocket is created that will listen message from the control plane -/// based on the parameters passed on [new][PhoenixChannel::new], from then on any messages sent with a sender -/// obtained by [PhoenixChannel::sender] will be forwarded to the websocket up to the control plane. Ingress messages -/// will be passed on to the `handler` provided in [PhoenixChannel::new]. -/// -/// The future returned by [PhoenixChannel::start] will finish when the websocket closes (by an error), meaning that if you -/// `await` it, it will block until you use `close` in a [PhoenixSender], the portal close the connection or something goes wrong. -pub struct PhoenixChannel { - secret_url: Secret, - os_version_override: Option, - handler: F, - sender: Sender, - receiver: Receiver, - _phantom: PhantomData<(I, R, M)>, -} - -// This is basically the same as tungstenite does but we add some new headers (namely user-agent) -fn make_request( - secret_url: &Secret, - os_version_override: Option, -) -> Result { - use secrecy::ExposeSecret; - - let host = secret_url.expose_secret().host(); - - let mut r = [0u8; 16]; - OsRng.fill_bytes(&mut r); - let key = base64::engine::general_purpose::STANDARD.encode(r); - - let req = Request::builder() - .method("GET") - .header("Host", host) - .header("Connection", "Upgrade") - .header("Upgrade", "websocket") - .header("Sec-WebSocket-Version", "13") - .header("Sec-WebSocket-Key", key) - .header("User-Agent", get_user_agent(os_version_override)) - .uri(secret_url.expose_secret().inner().as_ref()) - .body(())?; - Ok(req) -} - -impl PhoenixChannel -where - I: DeserializeOwned, - R: DeserializeOwned, - M: From + From, - F: Fn(MessageResult, Option, String) -> Fut, - Fut: Future + Send + 'static, -{ - /// Starts the tunnel with the parameters given in [Self::new]. - /// - // (Note: we could add a generic list of messages but this is easier) - /// Additionally, you can add a list of topic to join after connection ASAP. - /// - /// See [struct-level docs][PhoenixChannel] for more info. - /// - // TODO: this is not very elegant but it was the easiest way to do reset the exponential backoff for now - /// Furthermore, it calls the given callback once it connects to the portal. - pub async fn start( - &mut self, - topics: Vec, - after_connection_ends: impl FnOnce(), - ) -> Result<()> { - tracing::trace!("Trying to connect to portal..."); - - let (ws_stream, _) = connect_async(make_request( - &self.secret_url, - self.os_version_override.clone(), - )?) - .await?; - - tracing::trace!("Successfully connected to portal"); - - let (mut write, read) = ws_stream.split(); - - let mut sender = self.sender(); - let Self { - handler, receiver, .. - } = self; - - let process_messages = tokio_stream::StreamExt::map(read.timeout(HEARTBEAT_TIMEOUT), |m| { - m.map_err(Error::from)?.map_err(Error::from) - }) - .try_for_each(|message| async { Self::message_process(handler, message).await }); - - // Would we like to do write.send_all(futures::stream(Message::text(...))) ? - // yes. - // but since write is taken by reference rust doesn't believe this future is sendable anymore - // so this works for now, since we only use it with 1 topic. - for topic in topics { - write - .send(Message::Text( - // We don't care about the reply type when serializing - serde_json::to_string(&PhoenixMessage::<_, ()>::new( - topic, - EgressControlMessage::PhxJoin(Empty {}), - None, - )) - .expect("we should always be able to serialize a join topic message"), - )) - .await?; - } - - // TODO: is Forward cancel safe? - // I would assume it is and that's the advantage over - // while let Some(item) = receiver.next().await { write.send(item) } ... - // but double check this! - // If it's not cancel safe this means an item can be consumed and never sent. - // Furthermore can this also happen if write errors out? *that* I'd assume is possible... - // What option is left? write a new future to forward items. - // For now we should never assume that an item arrived the portal because we sent it! - let send_messages = futures::StreamExt::map(receiver, Ok) - .forward(write) - .map_err(Error::from); - - let phoenix_heartbeat = tokio::spawn(async move { - let mut timer = tokio::time::interval(HEARTBEAT); - loop { - timer.tick().await; - let Ok(_) = sender - .send("phoenix", EgressControlMessage::Heartbeat(Empty {})) - .await - else { - break; - }; - } - }); - - futures_util::pin_mut!(process_messages, send_messages); - // processing messages should be quick otherwise it'd block sending messages. - // we could remove this limitation by spawning a separate task for each of these. - let result = futures::future::select(process_messages, send_messages) - .await - .factor_first() - .0; - phoenix_heartbeat.abort(); - - after_connection_ends(); - - result?; - - Ok(()) - } - - async fn message_process(handler: &F, message: tungstenite::Message) -> Result<()> { - match message.into_text() { - Ok(m_str) => match serde_json::from_str::>(&m_str) { - Ok(m) => match m.payload { - Payload::Message(payload) => { - handler(Ok(payload.into()), m.reference, m.topic).await - } - Payload::Reply(status) => match status { - // TODO: Here we should pass error info to a subscriber - Reply::Error { reason } => { - tracing::debug!("Portal error: {reason:?}"); - handler(Err(ChannelError::ErrorReply(reason)), m.reference, m.topic) - .await - } - Reply::Ok(reply) => match reply { - OkReply::NoMessage(Empty {}) => { - tracing::trace!(target: "phoenix_status", "Phoenix status message") - } - OkReply::Message(payload) => { - handler(Ok(payload.into()), m.reference, m.topic).await - } - }, - }, - Payload::Error(_) => {} - Payload::Close(Empty {}) => return Err(Error::ClosedByPortal), - Payload::Disconnect { reason: _reason } => { - // TODO: pass the _reason up to the client so it can print a pertinent user message - handler( - Err(ChannelError::ErrorMsg(Error::ClosedByPortal)), - m.reference, - m.topic, - ) - .await - } - }, - Err(e) => { - tracing::error!(message = "Error deserializing message", message_string = m_str, error = ?e); - } - }, - _ => tracing::error!("Received message that is not text"), - } - - Ok(()) - } - - /// Obtains a new sender that can be used to send message with this [PhoenixChannel] to the portal. - /// - /// Note that for the sender to relay any message will need the future returned [PhoenixChannel::start] to be polled (await it), - /// and [PhoenixChannel::start] takes `&mut self`, meaning you need to get the sender before running [PhoenixChannel::start]. - pub fn sender(&self) -> PhoenixSender { - PhoenixSender { - sender: self.sender.clone(), - } - } - - /// Obtains a new sender that can be used to send message with this [PhoenixChannel] to the portal for a fixed topic. - /// - /// For more info see [PhoenixChannel::sender]. - pub fn sender_with_topic(&self, topic: String) -> PhoenixSenderWithTopic { - PhoenixSenderWithTopic { - topic, - phoenix_sender: self.sender(), - } - } - - /// Creates a new [PhoenixChannel] not started yet. - /// - /// # Parameters: - /// - `secret_url`: Portal's websocket uri - /// - `handler`: The handle that will be called for each received message. - /// - /// For more info see [struct-level docs][PhoenixChannel]. - pub fn new( - secret_url: Secret, - os_version_override: Option, - handler: F, - ) -> Self { - let (sender, receiver) = channel(CHANNEL_SIZE); - - Self { - sender, - receiver, - secret_url, - os_version_override, - handler, - _phantom: PhantomData, - } - } -} - -/// A result type that is used to communicate to the client/gateway -/// control loop the message received. -pub type MessageResult = std::result::Result; - -#[derive(Debug)] -pub enum ChannelError { - ErrorReply(ErrorReply), - ErrorMsg(Error), -} - -#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] -pub struct PhoenixMessage { - // TODO: we should use a newtype pattern for topics - topic: String, - #[serde(flatten)] - payload: Payload, - #[serde(rename = "ref")] - reference: Option, -} - -#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] -#[serde(tag = "event", content = "payload")] -enum Payload { - #[serde(rename = "phx_reply")] - Reply(Reply), - #[serde(rename = "phx_error")] - Error(Empty), - #[serde(rename = "phx_close")] - Close(Empty), - #[serde(rename = "disconnect")] - Disconnect { reason: String }, - #[serde(untagged)] - Message(T), -} - -// Awful hack to get serde_json to generate an empty "{}" instead of using "null" -#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] -#[serde(deny_unknown_fields)] -struct Empty {} - -#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] -#[serde(rename_all = "snake_case", tag = "status", content = "response")] -enum Reply { - Ok(OkReply), - Error { reason: ErrorReply }, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(untagged)] -enum OkReply { - Message(T), - NoMessage(Empty), -} - -/// This represents the info we have about the error -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum ErrorReply { - #[serde(rename = "unmatched topic")] - UnmatchedTopic, - TokenExpired, - NotFound, - Offline, - Disabled, - #[serde(other)] - Other, -} - -impl PhoenixMessage { - pub fn new(topic: impl Into, payload: T, reference: Option) -> Self { - Self { - topic: topic.into(), - payload: Payload::Message(payload), - reference, - } - } - - pub fn new_ok_reply( - topic: impl Into, - payload: R, - reference: impl Into>, - ) -> Self { - Self { - topic: topic.into(), - // There has to be a better way :\ - payload: Payload::Reply(Reply::Ok(OkReply::Message(payload))), - reference: reference.into(), - } - } - - pub fn new_err_reply( - topic: impl Into, - reason: ErrorReply, - reference: impl Into>, - ) -> Self { - Self { - topic: topic.into(), - // There has to be a better way :\ - payload: Payload::Reply(Reply::Error { reason }), - reference: reference.into(), - } - } -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -#[serde(rename_all = "snake_case", tag = "event", content = "payload")] -enum EgressControlMessage { - PhxJoin(Empty), - Heartbeat(Empty), -} - -/// You can use this sender to send messages through a `PhoenixChannel`. -/// -/// Messages won't be sent unless [PhoenixChannel::start] is running, internally -/// this sends messages through a future channel that are forwrarded then in [PhoenixChannel] event loop -#[derive(Clone, Debug)] -pub struct PhoenixSender { - sender: Sender, -} - -/// Like a [PhoenixSender] with a fixed topic for simplicity -/// -/// You can obtain it through [PhoenixChannel::sender_with_topic] -/// See [PhoenixSender] docs and use that if you need more control. -#[derive(Clone, Debug)] -pub struct PhoenixSenderWithTopic { - phoenix_sender: PhoenixSender, - topic: String, -} - -impl PhoenixSenderWithTopic { - /// Sends a message to the associated topic using a [PhoenixSender] - /// - /// See [PhoenixSender::send] - pub async fn send(&mut self, payload: impl Serialize) -> Result<()> { - self.phoenix_sender.send(&self.topic, payload).await - } - - /// Sends a message to the associated topic using a [PhoenixSender] also setting the ref - /// - /// See [PhoenixSender::send] - pub async fn send_with_ref( - &mut self, - payload: impl Serialize, - reference: impl ToString, - ) -> Result<()> { - self.phoenix_sender - .send_with_ref(&self.topic, payload, reference) - .await - } - - pub fn get_sender(&mut self) -> &mut PhoenixSender { - &mut self.phoenix_sender - } -} - -impl PhoenixSender { - async fn send_internal( - &mut self, - topic: impl Into, - payload: impl Serialize, - reference: Option, - ) -> Result<()> { - // We don't care about the reply type when serializing - let str = serde_json::to_string(&PhoenixMessage::<_, ()>::new(topic, payload, reference))?; - self.sender.send(Message::text(str)).await?; - Ok(()) - } - - /// Sends a message upstream to a connected [PhoenixChannel]. - /// - /// # Parameters - /// - topic: Phoenix topic - /// - payload: Message's payload - pub async fn send(&mut self, topic: impl Into, payload: impl Serialize) -> Result<()> { - self.send_internal(topic, payload, None).await - } - - /// Sends a message upstream to a connected [PhoenixChannel] using the given ref number. - /// - /// # Parameters - /// - topic: Phoenix topic - /// - payload: Message's payload - /// - reference: Reference number used in the message, if the message has a response that same number will be used - pub async fn send_with_ref( - &mut self, - topic: impl Into, - payload: impl Serialize, - reference: impl ToString, - ) -> Result<()> { - self.send_internal(topic, payload, Some(reference.to_string())) - .await - } - - /// Join a phoenix topic, meaning that after this method is invoked [PhoenixChannel] will - /// receive messages from that topic, given that upstream accepts you into the given topic. - pub async fn join_topic(&mut self, topic: impl Into) -> Result<()> { - self.send(topic, EgressControlMessage::PhxJoin(Empty {})) - .await - } - - /// Closes the [PhoenixChannel] - pub async fn close(&mut self) -> Result<()> { - self.sender.send(Message::Close(None)).await?; - self.sender.close().await?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn unmatched_topic_reply() { - let actual_reply = r#" - { - "event": "phx_reply", - "ref": "12", - "topic": "client", - "payload":{ - "status": "error", - "response":{ - "reason": "unmatched topic" - } - } - } - "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { - reason: ErrorReply::UnmatchedTopic, - }); - assert_eq!(actual_reply, expected_reply); - } - - #[test] - fn phx_close() { - let actual_reply = r#" - { - "event": "phx_close", - "ref": null, - "topic": "client", - "payload": {} - } - "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Close(Empty {}); - assert_eq!(actual_reply, expected_reply); - } - - #[test] - fn token_expired() { - let actual_reply = r#" - { - "event": "disconnect", - "ref": null, - "topic": "client", - "payload": { "reason": "token_expired" } - } - "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Disconnect { - reason: "token_expired".to_string(), - }; - assert_eq!(actual_reply, expected_reply); - } - - #[test] - fn not_found() { - let actual_reply = r#" - { - "event": "phx_reply", - "ref": null, - "topic": "client", - "payload": { - "status": "error", - "response": { - "reason": "not_found" - } - } - } - "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { - reason: ErrorReply::NotFound, - }); - assert_eq!(actual_reply, expected_reply); - } - - #[test] - fn unexpected_error_reply() { - let actual_reply = r#" - { - "event": "phx_reply", - "ref": "12", - "topic": "client", - "payload": { - "status": "error", - "response": { - "reason": "bad reply" - } - } - } - "#; - let actual_reply: Payload<(), ()> = serde_json::from_str(actual_reply).unwrap(); - let expected_reply = Payload::<(), ()>::Reply(Reply::Error { - reason: ErrorReply::Other, - }); - assert_eq!(actual_reply, expected_reply); - } -} diff --git a/rust/connlib/shared/src/error.rs b/rust/connlib/shared/src/error.rs index 1d8a0f2a9..67ab7ca27 100644 --- a/rust/connlib/shared/src/error.rs +++ b/rust/connlib/shared/src/error.rs @@ -1,9 +1,7 @@ //! Error module. -use base64::{DecodeError, DecodeSliceError}; -use boringtun::noise::errors::WireGuardError; +use base64::DecodeError; use std::net::IpAddr; use thiserror::Error; -use tokio::task::JoinError; /// Unified Result type to use across connlib. pub type Result = std::result::Result; @@ -14,45 +12,9 @@ pub enum ConnlibError { /// Standard IO error. #[error(transparent)] Io(#[from] std::io::Error), - /// Standard IO error. - #[error("Failed to roll over log file: {0}")] - LogFileRollError(std::io::Error), /// Error while decoding a base64 value. #[error("There was an error while decoding a base64 value: {0}")] Base64DecodeError(#[from] DecodeError), - /// Error while decoding a base64 value from a slice. - #[error("There was an error while decoding a base64 value: {0}")] - Base64DecodeSliceError(#[from] DecodeSliceError), - /// Request error for websocket connection. - #[error("Error forming request: {0}")] - RequestError(#[from] tokio_tungstenite::tungstenite::http::Error), - /// Websocket heartbeat timedout - #[error("Websocket heartbeat timedout")] - WebsocketTimeout(#[from] tokio_stream::Elapsed), - /// Error during websocket connection. - #[error("Portal connection error: {0}")] - PortalConnectionError(#[from] tokio_tungstenite::tungstenite::error::Error), - /// Provided string was not formatted as a URL. - #[error("Badly formatted URI")] - UriError, - /// Provided an unsupported uri string. - #[error("Unsupported URI scheme: Must be http://, https://, ws:// or wss://")] - UriScheme, - /// Serde's serialize error. - #[error(transparent)] - SerializeError(#[from] serde_json::Error), - /// Error while sending through an async channelchannel. - #[error("Error sending message through an async channel")] - SendChannelError, - /// Error when trying to establish connection between peers. - #[error("Error while establishing connection between peers")] - ConnectionEstablishError, - /// Error related to wireguard protocol. - #[error("Wireguard error")] - WireguardError(WireGuardError), - /// Expected an initialized runtime but there was none. - #[error("Expected runtime to be initialized")] - NoRuntime, /// Tried to access a resource which didn't exists. #[error("Tried to access an undefined resource")] UnknownResource, @@ -62,15 +24,9 @@ pub enum ConnlibError { /// Error regarding our own control protocol. #[error("Control plane protocol error. Unexpected messages or message order.")] ControlProtocolError, - /// Error when reading system's interface - #[error("Error while reading system's interface")] - IfaceRead(std::io::Error), /// Glob for errors without a type. #[error("Other error: {0}")] Other(&'static str), - /// Invalid tunnel name - #[error("Invalid tunnel name")] - InvalidTunnelName, #[cfg(target_os = "linux")] #[error(transparent)] NetlinkError(rtnetlink::Error), @@ -85,9 +41,6 @@ pub enum ConnlibError { /// Expected file descriptor and none was found #[error("No filedescriptor")] NoFd, - /// No MTU found - #[error("No MTU found")] - NoMtu, /// A panic occurred. #[error("Connlib panicked: {0}")] Panic(String), @@ -100,29 +53,12 @@ pub enum ConnlibError { /// Received connection details that might be stale #[error("Unexpected connection details")] UnexpectedConnectionDetails, - /// Invalid phoenix channel reference - #[error("Invalid phoenix channel reply reference")] - InvalidReference, - /// Invalid packet format - #[error("Received badly formatted packet")] - BadPacket, - /// Tunnel is under load - #[error("Under load")] - UnderLoad, - /// Invalid source address for peer - #[error("Invalid source address")] - InvalidSource, /// Invalid destination for packet #[error("Invalid dest address")] InvalidDst, - /// Any parse error - #[error("parse error")] - ParseError, /// Connection is still being established, retry later #[error("Pending connection")] PendingConnection, - #[error(transparent)] - Uuid(#[from] uuid::Error), #[cfg(target_os = "windows")] #[error("Windows error: {0}")] WindowsError(#[from] windows::core::Error), @@ -135,14 +71,6 @@ pub enum ConnlibError { #[cfg(target_os = "windows")] #[error("Can't find AppData/Local folder")] CantFindLocalAppDataFolder, - #[error("Token has expired")] - TokenExpired, - #[error("Too many concurrent gateway connection requests")] - TooManyConnectionRequests, - #[error("Channel connection closed by portal")] - ClosedByPortal, - #[error(transparent)] - JoinError(#[from] JoinError), #[cfg(target_os = "linux")] #[error("Error while rewriting `/etc/resolv.conf`: {0}")] @@ -161,16 +89,6 @@ pub enum ConnlibError { PortalConnectionFailed(phoenix_channel::Error), } -impl ConnlibError { - pub fn is_http_client_error(&self) -> bool { - matches!( - self, - Self::PortalConnectionError(tokio_tungstenite::tungstenite::error::Error::Http(e)) - if e.status().is_client_error() - ) - } -} - #[cfg(target_os = "linux")] impl From for ConnlibError { fn from(err: rtnetlink::Error) -> Self { @@ -181,27 +99,3 @@ impl From for ConnlibError { } } } - -impl From for ConnlibError { - fn from(e: WireGuardError) -> Self { - ConnlibError::WireguardError(e) - } -} - -impl From<&'static str> for ConnlibError { - fn from(e: &'static str) -> Self { - ConnlibError::Other(e) - } -} - -impl From> for ConnlibError { - fn from(_: tokio::sync::mpsc::error::SendError) -> Self { - ConnlibError::SendChannelError - } -} - -impl From for ConnlibError { - fn from(_: futures::channel::mpsc::SendError) -> Self { - ConnlibError::SendChannelError - } -} diff --git a/rust/connlib/shared/src/messages.rs b/rust/connlib/shared/src/messages.rs index 7ef6236ff..5f28d80fc 100644 --- a/rust/connlib/shared/src/messages.rs +++ b/rust/connlib/shared/src/messages.rs @@ -28,8 +28,6 @@ impl ResourceId { } #[derive(Hash, Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)] pub struct ClientId(Uuid); -#[derive(Hash, Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)] -pub struct ActorId(Uuid); impl FromStr for ResourceId { type Err = uuid::Error; @@ -111,7 +109,7 @@ pub struct RequestConnection { pub client_payload: ClientPayload, } -#[derive(Debug, Deserialize, Serialize, Clone)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct ClientPayload { pub ice_parameters: Offer, pub domain: Option, @@ -184,30 +182,30 @@ pub struct DomainResponse { pub address: Vec, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Answer { pub username: String, pub password: String, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Offer { pub username: String, pub password: String, } -#[derive(Debug, Deserialize, Serialize, Clone)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct ConnectionAccepted { pub ice_parameters: Answer, pub domain_response: Option, } -#[derive(Debug, Deserialize, Serialize, Clone)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct ResourceAccepted { pub domain_response: DomainResponse, } -#[derive(Debug, Deserialize, Serialize, Clone)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub enum GatewayResponse { ConnectionAccepted(ConnectionAccepted), ResourceAccepted(ResourceAccepted), diff --git a/rust/connlib/tunnel/src/device_channel/tun_windows.rs b/rust/connlib/tunnel/src/device_channel/tun_windows.rs index f619efa91..bf89b7a76 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_windows.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_windows.rs @@ -59,7 +59,7 @@ impl Tun { // The Windows client, in `wintun_install` hashes the DLL at startup, before calling connlib, so it's unlikely for the DLL to be accidentally corrupted by the time we get here. let path = connlib_shared::windows::wintun_dll_path()?; let wintun = unsafe { wintun::load_from_path(path) }?; - let uuid = uuid::Uuid::from_str(TUNNEL_UUID)?; + let uuid = uuid::Uuid::from_str(TUNNEL_UUID).expect("static UUID to parse correctly"); let adapter = match wintun::Adapter::create(&wintun, "Firezone", TUNNEL_NAME, Some(uuid.as_u128())) { Ok(x) => x, diff --git a/rust/gateway/src/messages.rs b/rust/gateway/src/messages.rs index fe90e4396..f7f77d55f 100644 --- a/rust/gateway/src/messages.rs +++ b/rust/gateway/src/messages.rs @@ -1,8 +1,8 @@ use chrono::{serde::ts_seconds_option, DateTime, Utc}; use connlib_shared::{ messages::{ - ActorId, ClientId, ClientPayload, GatewayResponse, Interface, Peer, Relay, - ResourceDescription, ResourceId, + ClientId, ClientPayload, GatewayResponse, Interface, Peer, Relay, ResourceDescription, + ResourceId, }, Dname, }; @@ -16,37 +16,21 @@ pub struct InitGateway { pub config: Config, } -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] -pub struct Actor { - pub id: ActorId, -} - #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] pub struct Config { pub ipv4_masquerade_enabled: bool, pub ipv6_masquerade_enabled: bool, } -#[derive(Debug, Deserialize, Serialize, Clone)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct Client { pub id: ClientId, pub payload: ClientPayload, pub peer: Peer, } -// rtc_sdp is ignored from eq since RTCSessionDescription doesn't implement this -// this will probably be changed in the future. -impl PartialEq for Client { - fn eq(&self, other: &Self) -> bool { - self.id == other.id && self.peer == other.peer - } -} - -impl Eq for Client {} - -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct RequestConnection { - pub actor: Actor, pub relays: Vec, pub resource: ResourceDescription, pub client: Client, @@ -99,7 +83,7 @@ pub struct RejectAccess { // These messages are the messages that can be received // either by a client or a gateway by the client. -#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Deserialize, Clone, PartialEq)] #[serde(rename_all = "snake_case", tag = "event", content = "payload")] pub enum IngressMessages { RequestConnection(RequestConnection),