diff --git a/rust/connlib/phoenix-channel/src/lib.rs b/rust/connlib/phoenix-channel/src/lib.rs index 89851ce37..cd583ac61 100644 --- a/rust/connlib/phoenix-channel/src/lib.rs +++ b/rust/connlib/phoenix-channel/src/lib.rs @@ -22,11 +22,11 @@ use secrecy::{ExposeSecret, Secret}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use socket_factory::{SocketFactory, TcpSocket, TcpStream}; use std::task::{Context, Poll, Waker}; -use tokio_tungstenite::client_async_tls; use tokio_tungstenite::{ MaybeTlsStream, WebSocketStream, tungstenite::{Message, handshake::client::Request}, }; +use tokio_tungstenite::{client_async_tls, tungstenite}; use url::Url; pub use get_user_agent::get_user_agent; @@ -140,6 +140,8 @@ pub enum Error { MaxRetriesReached { final_error: String }, #[error("Failed to login with portal: {0}")] LoginFailed(ErrorReply), + #[error("Fatal IO error: {0}")] + FatalIo(io::Error), } impl Error { @@ -149,13 +151,14 @@ impl Error { Error::TokenExpired => true, Error::MaxRetriesReached { .. } => false, Error::LoginFailed(_) => false, + Error::FatalIo(_) => false, } } } #[derive(Debug)] enum InternalError { - WebSocket(tokio_tungstenite::tungstenite::Error), + WebSocket(tungstenite::Error), Serde(serde_json::Error), CloseMessage, StreamClosed, @@ -166,7 +169,7 @@ enum InternalError { impl fmt::Display for InternalError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(http)) => { + InternalError::WebSocket(tungstenite::Error::Http(http)) => { let status = http.status(); let body = http .body() @@ -200,7 +203,7 @@ impl fmt::Display for InternalError { impl std::error::Error for InternalError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(_)) => None, + InternalError::WebSocket(tungstenite::Error::Http(_)) => None, InternalError::WebSocket(e) => Some(e), InternalError::Serde(e) => Some(e), InternalError::SocketConnection(_) => None, @@ -402,11 +405,17 @@ where continue; } - Poll::Ready(Err(InternalError::WebSocket( - tokio_tungstenite::tungstenite::Error::Http(r), - ))) if r.status().is_client_error() => { + Poll::Ready(Err(InternalError::WebSocket(tungstenite::Error::Http(r)))) + if r.status().is_client_error() => + { return Poll::Ready(Err(Error::Client(r.status()))); } + // Unfortunately, the underlying error gets stringified by tungstenite so we cannot match on anything other than the string. + Poll::Ready(Err(InternalError::WebSocket(tungstenite::Error::Io(io)))) + if io.to_string().starts_with("invalid peer certificate") => + { + return Poll::Ready(Err(Error::FatalIo(io))); + } Poll::Ready(Err(e)) => { let socket_addresses = self.socket_addresses(); let host = self.host();