feat: automatically rejoin channel on portal after reconnect (#3393)

In https://github.com/firezone/firezone/pull/3364, we forgot to rejoin
the channel on the portal. Additionally, I found a way to detect the
disconnect even more quickly.
This commit is contained in:
Thomas Eizinger
2024-01-24 18:05:15 -08:00
committed by GitHub
parent 31f2f52d94
commit f9f95677d5
7 changed files with 68 additions and 31 deletions

1
rust/Cargo.lock generated
View File

@@ -4612,6 +4612,7 @@ dependencies = [
name = "phoenix-channel"
version = "1.0.0"
dependencies = [
"anyhow",
"backoff",
"base64 0.21.7",
"futures",

View File

@@ -17,7 +17,7 @@ pub const PHOENIX_TOPIC: &str = "gateway";
pub struct Eventloop {
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
portal: PhoenixChannel<IngressMessages, EgressMessages>,
portal: PhoenixChannel<(), IngressMessages, EgressMessages>,
// TODO: Strongly type request reference (currently `String`)
connection_request_tasks:
@@ -31,7 +31,7 @@ pub struct Eventloop {
impl Eventloop {
pub(crate) fn new(
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
portal: PhoenixChannel<IngressMessages, EgressMessages>,
portal: PhoenixChannel<(), IngressMessages, EgressMessages>,
) -> Self {
Self {
tunnel,
@@ -234,6 +234,13 @@ impl Eventloop {
}
continue;
}
Poll::Ready(phoenix_channel::Event::InboundMessage {
msg: IngressMessages::Init(_),
..
}) => {
// TODO: Handle `init` message during operation.
continue;
}
_ => {}
}

View File

@@ -82,7 +82,7 @@ async fn run(connect_url: Url, private_key: StaticSecret) -> Result<Infallible>
let tunnel: Arc<Tunnel<_, GatewayState>> =
Arc::new(Tunnel::new(private_key, CallbackHandler).await?);
let (portal, init) = phoenix_channel::init::<InitGateway, _, _>(
let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>(
Secret::new(SecureUrl::from_url(connect_url.clone())),
get_user_agent(None),
PHOENIX_TOPIC,

View File

@@ -97,6 +97,7 @@ pub enum IngressMessages {
RequestConnection(RequestConnection),
AllowAccess(AllowAccess),
IceCandidates(ClientIceCandidates),
Init(InitGateway),
}
/// A client's ice candidate message.

View File

@@ -19,3 +19,4 @@ serde_json = "1.0.108"
thiserror = "1.0.50"
tokio = { version = "1.33.0", features = ["net", "time"] }
backoff = "0.4.0"
anyhow = "1"

View File

@@ -24,7 +24,7 @@ const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
// TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway
// See https://github.com/firezone/firezone/issues/2158
pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
state: State,
pending_messages: Vec<Message>,
next_request_id: u64,
@@ -39,6 +39,9 @@ pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
secret_url: Secret<SecureUrl>,
user_agent: String,
reconnect_backoff: ExponentialBackoff,
login: &'static str,
init_req: TInitReq,
}
enum State {
@@ -52,35 +55,40 @@ enum State {
/// Additionally, you must already provide any query parameters required for authentication.
#[tracing::instrument(level = "debug", skip(payload, secret_url, reconnect_backoff))]
#[allow(clippy::type_complexity)]
pub async fn init<TInitM, TInboundMsg, TOutboundRes>(
pub async fn init<TInitReq, TInitRes, TInboundMsg, TOutboundRes>(
secret_url: Secret<SecureUrl>,
user_agent: String,
login_topic: &'static str,
payload: impl Serialize,
payload: TInitReq,
reconnect_backoff: ExponentialBackoff,
) -> Result<
Result<(PhoenixChannel<TInboundMsg, TOutboundRes>, TInitM), UnexpectedEventDuringInit>,
Result<
(
PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes>,
TInitRes,
),
UnexpectedEventDuringInit,
>,
Error,
>
where
TInitM: DeserializeOwned + fmt::Debug,
TInitReq: Serialize + Clone,
TInitRes: DeserializeOwned + fmt::Debug,
TInboundMsg: DeserializeOwned,
TOutboundRes: DeserializeOwned,
{
let mut channel = PhoenixChannel::<InitMessage<TInitM>, ()>::connect(
let mut channel = PhoenixChannel::<_, InitMessage<TInitRes>, ()>::connect(
secret_url,
user_agent,
login_topic,
payload,
reconnect_backoff,
); // No reconnection on `init`.
channel.join(login_topic, payload);
);
tracing::info!("Connected to portal, waiting for `init` message");
let (channel, init_message) = loop {
match future::poll_fn(|cx| channel.poll(cx)).await? {
Event::JoinedRoom { topic } if topic == login_topic => {
tracing::info!("Joined {login_topic} room on portal")
}
Event::InboundMessage {
topic,
msg: InitMessage::Init(msg),
@@ -157,8 +165,9 @@ impl secrecy::Zeroize for SecureUrl {
}
}
impl<TInboundMsg, TOutboundRes> PhoenixChannel<TInboundMsg, TOutboundRes>
impl<TInitReq, TInboundMsg, TOutboundRes> PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes>
where
TInitReq: Serialize + Clone,
TInboundMsg: DeserializeOwned,
TOutboundRes: DeserializeOwned,
{
@@ -166,12 +175,16 @@ where
///
/// The provided URL must contain a host.
/// Additionally, you must already provide any query parameters required for authentication.
///
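/// Once the connection is established, the channel joins the `login` room using `init_req`.
/// The same join request is sent again after every successful reconnect.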
pub fn connect(
secret_url: Secret<SecureUrl>,
user_agent: String,
login: &'static str,
init_req: TInitReq,
reconnect_backoff: ExponentialBackoff,
) -> Self {
Self {
let mut phoenix_channel = Self {
reconnect_backoff,
secret_url: secret_url.clone(),
user_agent: user_agent.clone(),
@@ -185,7 +198,12 @@ where
next_request_id: 0,
next_heartbeat: Box::pin(tokio::time::sleep(HEARTBEAT_INTERVAL)),
pending_join_requests: Default::default(),
}
login,
init_req: init_req.clone(),
};
phoenix_channel.join(login, init_req);
phoenix_channel
}
/// Join the provided room.
@@ -216,6 +234,7 @@ where
self.state = State::Connected(stream);
tracing::info!("Connected to portal");
self.join(self.login, self.init_req.clone());
continue;
}
@@ -243,10 +262,11 @@ where
tracing::warn!("Reconnect backoff expired");
return Poll::Ready(Err(e));
};
let secret_url = self.secret_url.clone();
let user_agent = self.user_agent.clone();
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal");
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {:#}", anyhow::Error::from(e));
self.state = State::Connecting(Box::pin(async move {
tokio::time::sleep(backoff).await;
@@ -267,7 +287,7 @@ where
match stream.start_send_unpin(message) {
Ok(()) => {}
Err(e) => {
self.reconnect_on_transient_error(e);
self.reconnect_on_transient_error(Error::WebSocket(e));
}
}
continue;
@@ -289,6 +309,10 @@ where
>(&text)
{
Ok(m) => m,
Err(e) if e.is_io() || e.is_eof() => {
self.reconnect_on_transient_error(Error::Serde(e));
continue;
}
Err(e) => {
tracing::warn!("Failed to deserialize message {text}: {e}");
continue;
@@ -301,7 +325,7 @@ where
return Poll::Ready(Ok(Event::InboundMessage {
topic: message.topic,
msg,
}))
}));
}
Some(reference) => {
return Poll::Ready(Ok(Event::InboundReq {
@@ -328,6 +352,8 @@ where
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
if self.pending_join_requests.remove(&req_id) {
tracing::info!("Joined {} room on portal", message.topic);
// For `phx_join` requests, `reply` is empty so we can safely ignore it.
return Poll::Ready(Ok(Event::JoinedRoom {
topic: message.topic,
@@ -370,7 +396,7 @@ where
}
}
Poll::Ready(Some(Err(e))) => {
self.reconnect_on_transient_error(e);
self.reconnect_on_transient_error(Error::WebSocket(e));
continue;
}
_ => (),
@@ -392,7 +418,7 @@ where
tracing::trace!("Flushed websocket");
}
Poll::Ready(Err(e)) => {
self.reconnect_on_transient_error(e);
self.reconnect_on_transient_error(Error::WebSocket(e));
continue;
}
Poll::Pending => {}
@@ -405,9 +431,8 @@ where
/// Sets the channel's state to [`State::Connecting`] with the given error.
///
/// The [`PhoenixChannel::poll`] function will handle the reconnect if appropriate for the given error.
fn reconnect_on_transient_error(&mut self, e: tokio_tungstenite::tungstenite::Error) {
tracing::info!("Websocket disconnected: {e:#?}");
self.state = State::Connecting(future::ready(Err(Error::WebSocket(e))).boxed())
fn reconnect_on_transient_error(&mut self, e: Error) {
self.state = State::Connecting(future::ready(Err(e)).boxed())
}
fn send_message(
@@ -436,7 +461,7 @@ where
/// Cast this instance of [PhoenixChannel] to new message types.
fn cast<TInboundMsgNew, TOutboundResNew>(
self,
) -> PhoenixChannel<TInboundMsgNew, TOutboundResNew> {
) -> PhoenixChannel<TInitReq, TInboundMsgNew, TOutboundResNew> {
PhoenixChannel {
state: self.state,
pending_messages: self.pending_messages,
@@ -447,6 +472,8 @@ where
secret_url: self.secret_url,
user_agent: self.user_agent,
reconnect_backoff: self.reconnect_backoff,
login: self.login,
init_req: self.init_req,
}
}
}

View File

@@ -246,7 +246,7 @@ async fn connect_to_portal(
token: &SecretString,
mut url: Url,
stamp_secret: &SecretString,
) -> Result<Option<PhoenixChannel<(), ()>>> {
) -> Result<Option<PhoenixChannel<JoinMessage, (), ()>>> {
use secrecy::ExposeSecret;
if !url.path().is_empty() {
@@ -266,7 +266,7 @@ async fn connect_to_portal(
.append_pair("ipv6", &public_ip6_addr.to_string());
}
let (channel, Init {}) = phoenix_channel::init::<Init, _, _>(
let (channel, Init {}) = phoenix_channel::init::<_, Init, _, _>(
Secret::from(SecureUrl::from_url(url)),
format!("relay/{}", env!("CARGO_PKG_VERSION")),
"relay",
@@ -285,7 +285,7 @@ async fn connect_to_portal(
#[derive(serde::Deserialize, Debug)]
struct Init {}
#[derive(serde::Serialize, PartialEq, Debug)]
#[derive(serde::Serialize, PartialEq, Debug, Clone)]
struct JoinMessage {
stamp_secret: String,
}
@@ -315,7 +315,7 @@ struct Eventloop<R> {
outbound_ip4_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
outbound_ip6_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
server: Server<R>,
channel: Option<PhoenixChannel<(), ()>>,
channel: Option<PhoenixChannel<JoinMessage, (), ()>>,
allocations: HashMap<(AllocationId, AddressFamily), Allocation>,
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
relay_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr, AllocationId)>,
@@ -328,7 +328,7 @@ where
{
fn new(
server: Server<R>,
channel: Option<PhoenixChannel<(), ()>>,
channel: Option<PhoenixChannel<JoinMessage, (), ()>>,
public_address: IpStack,
) -> Result<Self> {
let (relay_data_sender, relay_data_receiver) = mpsc::channel(1);