mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
feat: automatically rejoin channel on portal after reconnect (#3393)
In https://github.com/firezone/firezone/pull/3364, we forgot to rejoin the channel on the portal. Additionally, I found a way to detect the disconnect even more quickly.
This commit is contained in:
1
rust/Cargo.lock
generated
1
rust/Cargo.lock
generated
@@ -4612,6 +4612,7 @@ dependencies = [
|
||||
name = "phoenix-channel"
|
||||
version = "1.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"backoff",
|
||||
"base64 0.21.7",
|
||||
"futures",
|
||||
|
||||
@@ -17,7 +17,7 @@ pub const PHOENIX_TOPIC: &str = "gateway";
|
||||
|
||||
pub struct Eventloop {
|
||||
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
|
||||
portal: PhoenixChannel<IngressMessages, EgressMessages>,
|
||||
portal: PhoenixChannel<(), IngressMessages, EgressMessages>,
|
||||
|
||||
// TODO: Strongly type request reference (currently `String`)
|
||||
connection_request_tasks:
|
||||
@@ -31,7 +31,7 @@ pub struct Eventloop {
|
||||
impl Eventloop {
|
||||
pub(crate) fn new(
|
||||
tunnel: Arc<Tunnel<CallbackHandler, GatewayState>>,
|
||||
portal: PhoenixChannel<IngressMessages, EgressMessages>,
|
||||
portal: PhoenixChannel<(), IngressMessages, EgressMessages>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tunnel,
|
||||
@@ -234,6 +234,13 @@ impl Eventloop {
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(phoenix_channel::Event::InboundMessage {
|
||||
msg: IngressMessages::Init(_),
|
||||
..
|
||||
}) => {
|
||||
// TODO: Handle `init` message during operation.
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ async fn run(connect_url: Url, private_key: StaticSecret) -> Result<Infallible>
|
||||
let tunnel: Arc<Tunnel<_, GatewayState>> =
|
||||
Arc::new(Tunnel::new(private_key, CallbackHandler).await?);
|
||||
|
||||
let (portal, init) = phoenix_channel::init::<InitGateway, _, _>(
|
||||
let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>(
|
||||
Secret::new(SecureUrl::from_url(connect_url.clone())),
|
||||
get_user_agent(None),
|
||||
PHOENIX_TOPIC,
|
||||
|
||||
@@ -97,6 +97,7 @@ pub enum IngressMessages {
|
||||
RequestConnection(RequestConnection),
|
||||
AllowAccess(AllowAccess),
|
||||
IceCandidates(ClientIceCandidates),
|
||||
Init(InitGateway),
|
||||
}
|
||||
|
||||
/// A client's ice candidate message.
|
||||
|
||||
@@ -19,3 +19,4 @@ serde_json = "1.0.108"
|
||||
thiserror = "1.0.50"
|
||||
tokio = { version = "1.33.0", features = ["net", "time"] }
|
||||
backoff = "0.4.0"
|
||||
anyhow = "1"
|
||||
|
||||
@@ -24,7 +24,7 @@ const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
// TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway
|
||||
// See https://github.com/firezone/firezone/issues/2158
|
||||
pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
|
||||
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
|
||||
state: State,
|
||||
pending_messages: Vec<Message>,
|
||||
next_request_id: u64,
|
||||
@@ -39,6 +39,9 @@ pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
|
||||
secret_url: Secret<SecureUrl>,
|
||||
user_agent: String,
|
||||
reconnect_backoff: ExponentialBackoff,
|
||||
|
||||
login: &'static str,
|
||||
init_req: TInitReq,
|
||||
}
|
||||
|
||||
enum State {
|
||||
@@ -52,35 +55,40 @@ enum State {
|
||||
/// Additionally, you must already provide any query parameters required for authentication.
|
||||
#[tracing::instrument(level = "debug", skip(payload, secret_url, reconnect_backoff))]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub async fn init<TInitM, TInboundMsg, TOutboundRes>(
|
||||
pub async fn init<TInitReq, TInitRes, TInboundMsg, TOutboundRes>(
|
||||
secret_url: Secret<SecureUrl>,
|
||||
user_agent: String,
|
||||
login_topic: &'static str,
|
||||
payload: impl Serialize,
|
||||
payload: TInitReq,
|
||||
reconnect_backoff: ExponentialBackoff,
|
||||
) -> Result<
|
||||
Result<(PhoenixChannel<TInboundMsg, TOutboundRes>, TInitM), UnexpectedEventDuringInit>,
|
||||
Result<
|
||||
(
|
||||
PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes>,
|
||||
TInitRes,
|
||||
),
|
||||
UnexpectedEventDuringInit,
|
||||
>,
|
||||
Error,
|
||||
>
|
||||
where
|
||||
TInitM: DeserializeOwned + fmt::Debug,
|
||||
TInitReq: Serialize + Clone,
|
||||
TInitRes: DeserializeOwned + fmt::Debug,
|
||||
TInboundMsg: DeserializeOwned,
|
||||
TOutboundRes: DeserializeOwned,
|
||||
{
|
||||
let mut channel = PhoenixChannel::<InitMessage<TInitM>, ()>::connect(
|
||||
let mut channel = PhoenixChannel::<_, InitMessage<TInitRes>, ()>::connect(
|
||||
secret_url,
|
||||
user_agent,
|
||||
login_topic,
|
||||
payload,
|
||||
reconnect_backoff,
|
||||
); // No reconnection on `init`.
|
||||
channel.join(login_topic, payload);
|
||||
);
|
||||
|
||||
tracing::info!("Connected to portal, waiting for `init` message");
|
||||
|
||||
let (channel, init_message) = loop {
|
||||
match future::poll_fn(|cx| channel.poll(cx)).await? {
|
||||
Event::JoinedRoom { topic } if topic == login_topic => {
|
||||
tracing::info!("Joined {login_topic} room on portal")
|
||||
}
|
||||
Event::InboundMessage {
|
||||
topic,
|
||||
msg: InitMessage::Init(msg),
|
||||
@@ -157,8 +165,9 @@ impl secrecy::Zeroize for SecureUrl {
|
||||
}
|
||||
}
|
||||
|
||||
impl<TInboundMsg, TOutboundRes> PhoenixChannel<TInboundMsg, TOutboundRes>
|
||||
impl<TInitReq, TInboundMsg, TOutboundRes> PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes>
|
||||
where
|
||||
TInitReq: Serialize + Clone,
|
||||
TInboundMsg: DeserializeOwned,
|
||||
TOutboundRes: DeserializeOwned,
|
||||
{
|
||||
@@ -166,12 +175,16 @@ where
|
||||
///
|
||||
/// The provided URL must contain a host.
|
||||
/// Additionally, you must already provide any query parameters required for authentication.
|
||||
///
|
||||
/// Once the connection is established,
|
||||
pub fn connect(
|
||||
secret_url: Secret<SecureUrl>,
|
||||
user_agent: String,
|
||||
login: &'static str,
|
||||
init_req: TInitReq,
|
||||
reconnect_backoff: ExponentialBackoff,
|
||||
) -> Self {
|
||||
Self {
|
||||
let mut phoenix_channel = Self {
|
||||
reconnect_backoff,
|
||||
secret_url: secret_url.clone(),
|
||||
user_agent: user_agent.clone(),
|
||||
@@ -185,7 +198,12 @@ where
|
||||
next_request_id: 0,
|
||||
next_heartbeat: Box::pin(tokio::time::sleep(HEARTBEAT_INTERVAL)),
|
||||
pending_join_requests: Default::default(),
|
||||
}
|
||||
login,
|
||||
init_req: init_req.clone(),
|
||||
};
|
||||
phoenix_channel.join(login, init_req);
|
||||
|
||||
phoenix_channel
|
||||
}
|
||||
|
||||
/// Join the provided room.
|
||||
@@ -216,6 +234,7 @@ where
|
||||
self.state = State::Connected(stream);
|
||||
|
||||
tracing::info!("Connected to portal");
|
||||
self.join(self.login, self.init_req.clone());
|
||||
|
||||
continue;
|
||||
}
|
||||
@@ -243,10 +262,11 @@ where
|
||||
tracing::warn!("Reconnect backoff expired");
|
||||
return Poll::Ready(Err(e));
|
||||
};
|
||||
|
||||
let secret_url = self.secret_url.clone();
|
||||
let user_agent = self.user_agent.clone();
|
||||
|
||||
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal");
|
||||
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {:#}", anyhow::Error::from(e));
|
||||
|
||||
self.state = State::Connecting(Box::pin(async move {
|
||||
tokio::time::sleep(backoff).await;
|
||||
@@ -267,7 +287,7 @@ where
|
||||
match stream.start_send_unpin(message) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
self.reconnect_on_transient_error(e);
|
||||
self.reconnect_on_transient_error(Error::WebSocket(e));
|
||||
}
|
||||
}
|
||||
continue;
|
||||
@@ -289,6 +309,10 @@ where
|
||||
>(&text)
|
||||
{
|
||||
Ok(m) => m,
|
||||
Err(e) if e.is_io() || e.is_eof() => {
|
||||
self.reconnect_on_transient_error(Error::Serde(e));
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to deserialize message {text}: {e}");
|
||||
continue;
|
||||
@@ -301,7 +325,7 @@ where
|
||||
return Poll::Ready(Ok(Event::InboundMessage {
|
||||
topic: message.topic,
|
||||
msg,
|
||||
}))
|
||||
}));
|
||||
}
|
||||
Some(reference) => {
|
||||
return Poll::Ready(Ok(Event::InboundReq {
|
||||
@@ -328,6 +352,8 @@ where
|
||||
OutboundRequestId(message.reference.ok_or(Error::MissingReplyId)?);
|
||||
|
||||
if self.pending_join_requests.remove(&req_id) {
|
||||
tracing::info!("Joined {} room on portal", message.topic);
|
||||
|
||||
// For `phx_join` requests, `reply` is empty so we can safely ignore it.
|
||||
return Poll::Ready(Ok(Event::JoinedRoom {
|
||||
topic: message.topic,
|
||||
@@ -370,7 +396,7 @@ where
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => {
|
||||
self.reconnect_on_transient_error(e);
|
||||
self.reconnect_on_transient_error(Error::WebSocket(e));
|
||||
continue;
|
||||
}
|
||||
_ => (),
|
||||
@@ -392,7 +418,7 @@ where
|
||||
tracing::trace!("Flushed websocket");
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
self.reconnect_on_transient_error(e);
|
||||
self.reconnect_on_transient_error(Error::WebSocket(e));
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {}
|
||||
@@ -405,9 +431,8 @@ where
|
||||
/// Sets the channels state to [`State::Connecting`] with the given error.
|
||||
///
|
||||
/// The [`PhoenixChannel::poll`] function will handle the reconnect if appropriate for the given error.
|
||||
fn reconnect_on_transient_error(&mut self, e: tokio_tungstenite::tungstenite::Error) {
|
||||
tracing::info!("Websocket disconnected: {e:#?}");
|
||||
self.state = State::Connecting(future::ready(Err(Error::WebSocket(e))).boxed())
|
||||
fn reconnect_on_transient_error(&mut self, e: Error) {
|
||||
self.state = State::Connecting(future::ready(Err(e)).boxed())
|
||||
}
|
||||
|
||||
fn send_message(
|
||||
@@ -436,7 +461,7 @@ where
|
||||
/// Cast this instance of [PhoenixChannel] to new message types.
|
||||
fn cast<TInboundMsgNew, TOutboundResNew>(
|
||||
self,
|
||||
) -> PhoenixChannel<TInboundMsgNew, TOutboundResNew> {
|
||||
) -> PhoenixChannel<TInitReq, TInboundMsgNew, TOutboundResNew> {
|
||||
PhoenixChannel {
|
||||
state: self.state,
|
||||
pending_messages: self.pending_messages,
|
||||
@@ -447,6 +472,8 @@ where
|
||||
secret_url: self.secret_url,
|
||||
user_agent: self.user_agent,
|
||||
reconnect_backoff: self.reconnect_backoff,
|
||||
login: self.login,
|
||||
init_req: self.init_req,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -246,7 +246,7 @@ async fn connect_to_portal(
|
||||
token: &SecretString,
|
||||
mut url: Url,
|
||||
stamp_secret: &SecretString,
|
||||
) -> Result<Option<PhoenixChannel<(), ()>>> {
|
||||
) -> Result<Option<PhoenixChannel<JoinMessage, (), ()>>> {
|
||||
use secrecy::ExposeSecret;
|
||||
|
||||
if !url.path().is_empty() {
|
||||
@@ -266,7 +266,7 @@ async fn connect_to_portal(
|
||||
.append_pair("ipv6", &public_ip6_addr.to_string());
|
||||
}
|
||||
|
||||
let (channel, Init {}) = phoenix_channel::init::<Init, _, _>(
|
||||
let (channel, Init {}) = phoenix_channel::init::<_, Init, _, _>(
|
||||
Secret::from(SecureUrl::from_url(url)),
|
||||
format!("relay/{}", env!("CARGO_PKG_VERSION")),
|
||||
"relay",
|
||||
@@ -285,7 +285,7 @@ async fn connect_to_portal(
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
struct Init {}
|
||||
|
||||
#[derive(serde::Serialize, PartialEq, Debug)]
|
||||
#[derive(serde::Serialize, PartialEq, Debug, Clone)]
|
||||
struct JoinMessage {
|
||||
stamp_secret: String,
|
||||
}
|
||||
@@ -315,7 +315,7 @@ struct Eventloop<R> {
|
||||
outbound_ip4_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
|
||||
outbound_ip6_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
|
||||
server: Server<R>,
|
||||
channel: Option<PhoenixChannel<(), ()>>,
|
||||
channel: Option<PhoenixChannel<JoinMessage, (), ()>>,
|
||||
allocations: HashMap<(AllocationId, AddressFamily), Allocation>,
|
||||
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
relay_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
@@ -328,7 +328,7 @@ where
|
||||
{
|
||||
fn new(
|
||||
server: Server<R>,
|
||||
channel: Option<PhoenixChannel<(), ()>>,
|
||||
channel: Option<PhoenixChannel<JoinMessage, (), ()>>,
|
||||
public_address: IpStack,
|
||||
) -> Result<Self> {
|
||||
let (relay_data_sender, relay_data_receiver) = mpsc::channel(1);
|
||||
|
||||
Reference in New Issue
Block a user