mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
feat(phoenix-channel): fail on missing heartbeat after 5s (#4296)
This PR fixes a bug and adds a missing feature to `phoenix-channel`. 1. Previously, we used to erroneously reset the heartbeat state on all sorts of empty replies, not just the specific one from the heartbeat. 2. We only failed on missing heartbeats when it was time to send the next one. With this PR, we correct the first bug and add a dedicated timeout of 5s for the heartbeat reply.
This commit is contained in:
@@ -1,81 +1,93 @@
|
||||
use crate::{EgressControlMessage, OutboundRequestId};
|
||||
use crate::OutboundRequestId;
|
||||
use futures::FutureExt;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::{atomic::AtomicU64, Arc},
|
||||
task::{ready, Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::time::MissedTickBehavior;
|
||||
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
|
||||
pub const INTERVAL: Duration = Duration::from_secs(30);
|
||||
pub const TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
pub struct Heartbeat {
|
||||
/// When to send the next heartbeat.
|
||||
interval: Pin<Box<tokio::time::Interval>>,
|
||||
|
||||
timeout: Duration,
|
||||
|
||||
/// The ID of our heatbeat if we haven't received a reply yet.
|
||||
id: Option<OutboundRequestId>,
|
||||
pending: Option<(OutboundRequestId, Pin<Box<tokio::time::Sleep>>)>,
|
||||
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl Heartbeat {
|
||||
pub fn maybe_handle_reply(&mut self, id: OutboundRequestId) -> bool {
|
||||
let Some(pending) = self.id.take() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if pending != id {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.id = None;
|
||||
true
|
||||
}
|
||||
|
||||
pub fn set_id(&mut self, id: OutboundRequestId) {
|
||||
self.id = Some(id);
|
||||
}
|
||||
|
||||
pub fn poll(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
) -> Poll<Result<EgressControlMessage<()>, MissedLastHeartbeat>> {
|
||||
ready!(self.interval.poll_tick(cx));
|
||||
|
||||
if self.id.is_some() {
|
||||
self.id = None;
|
||||
return Poll::Ready(Err(MissedLastHeartbeat {}));
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(EgressControlMessage::Heartbeat(crate::Empty {})))
|
||||
}
|
||||
|
||||
fn new(interval: Duration) -> Self {
|
||||
pub fn new(interval: Duration, timeout: Duration, next_request_id: Arc<AtomicU64>) -> Self {
|
||||
let mut interval = tokio::time::interval(interval);
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
|
||||
Self {
|
||||
interval: Box::pin(interval),
|
||||
id: Default::default(),
|
||||
pending: Default::default(),
|
||||
next_request_id,
|
||||
timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn maybe_handle_reply(&mut self, id: OutboundRequestId) -> bool {
|
||||
match self.pending.as_ref() {
|
||||
Some((pending, timeout)) if pending == &id && !dbg!(timeout.is_elapsed()) => {
|
||||
self.pending = None;
|
||||
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll(
|
||||
&mut self,
|
||||
cx: &mut Context,
|
||||
) -> Poll<Result<OutboundRequestId, MissedLastHeartbeat>> {
|
||||
if let Some((_, timeout)) = self.pending.as_mut() {
|
||||
ready!(timeout.poll_unpin(cx));
|
||||
self.pending = None;
|
||||
return Poll::Ready(Err(MissedLastHeartbeat {}));
|
||||
}
|
||||
|
||||
ready!(self.interval.poll_tick(cx));
|
||||
|
||||
let next_id = self
|
||||
.next_request_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
self.pending = Some((
|
||||
OutboundRequestId(next_id),
|
||||
Box::pin(tokio::time::sleep(self.timeout)),
|
||||
));
|
||||
|
||||
Poll::Ready(Ok(OutboundRequestId(next_id)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MissedLastHeartbeat {}
|
||||
|
||||
impl Default for Heartbeat {
|
||||
fn default() -> Self {
|
||||
Self::new(HEARTBEAT_INTERVAL)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::future::Either;
|
||||
use std::{future::poll_fn, time::Instant};
|
||||
|
||||
const INTERVAL: Duration = Duration::from_millis(30);
|
||||
const TIMEOUT: Duration = Duration::from_millis(5);
|
||||
|
||||
#[tokio::test]
|
||||
async fn returns_heartbeat_after_interval() {
|
||||
let mut heartbeat = Heartbeat::new(Duration::from_millis(30));
|
||||
let _ = poll_fn(|cx| heartbeat.poll(cx)).await; // Tick once at startup.
|
||||
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));
|
||||
let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap(); // Tick once at startup.
|
||||
heartbeat.maybe_handle_reply(id);
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
@@ -84,15 +96,25 @@ mod tests {
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert!(elapsed >= Duration::from_millis(10));
|
||||
assert!(elapsed >= INTERVAL);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fails_if_response_is_not_provided_before_next_poll() {
|
||||
let mut heartbeat = Heartbeat::new(Duration::from_millis(10));
|
||||
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));
|
||||
|
||||
let _ = poll_fn(|cx| heartbeat.poll(cx)).await;
|
||||
heartbeat.set_id(OutboundRequestId::for_test(1));
|
||||
|
||||
let result = poll_fn(|cx| heartbeat.poll(cx)).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ignores_other_ids() {
|
||||
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));
|
||||
|
||||
let _ = poll_fn(|cx| heartbeat.poll(cx)).await;
|
||||
heartbeat.maybe_handle_reply(OutboundRequestId::for_test(2));
|
||||
|
||||
let result = poll_fn(|cx| heartbeat.poll(cx)).await;
|
||||
assert!(result.is_err());
|
||||
@@ -100,13 +122,34 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn succeeds_if_response_is_provided_inbetween_polls() {
|
||||
let mut heartbeat = Heartbeat::new(Duration::from_millis(10));
|
||||
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));
|
||||
|
||||
let _ = poll_fn(|cx| heartbeat.poll(cx)).await;
|
||||
heartbeat.set_id(OutboundRequestId::for_test(1));
|
||||
heartbeat.maybe_handle_reply(OutboundRequestId::for_test(1));
|
||||
let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap();
|
||||
heartbeat.maybe_handle_reply(id);
|
||||
|
||||
let result = poll_fn(|cx| heartbeat.poll(cx)).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fails_if_not_provided_within_timeout() {
|
||||
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));
|
||||
|
||||
let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap();
|
||||
|
||||
let select = futures::future::select(
|
||||
tokio::time::sleep(TIMEOUT * 2).boxed(),
|
||||
poll_fn(|cx| heartbeat.poll(cx)),
|
||||
)
|
||||
.await;
|
||||
|
||||
match select {
|
||||
Either::Left(((), _)) => panic!("timeout should not resolve"),
|
||||
Either::Right((Ok(_), _)) => panic!("heartbeat should fail and not issue new ID"),
|
||||
Either::Right((Err(_), _)) => {}
|
||||
}
|
||||
|
||||
let handled = heartbeat.maybe_handle_reply(id);
|
||||
assert!(!handled);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,8 @@ use tokio_tungstenite::{
|
||||
};
|
||||
|
||||
pub use login_url::{LoginUrl, LoginUrlError};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
|
||||
// TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway
|
||||
// See https://github.com/firezone/firezone/issues/2158
|
||||
@@ -30,7 +32,7 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
|
||||
state: State,
|
||||
waker: Option<Waker>,
|
||||
pending_messages: VecDeque<String>,
|
||||
next_request_id: u64,
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
|
||||
heartbeat: Heartbeat,
|
||||
|
||||
@@ -208,6 +210,8 @@ where
|
||||
init_req: TInitReq,
|
||||
reconnect_backoff: ExponentialBackoff,
|
||||
) -> Self {
|
||||
let next_request_id = Arc::new(AtomicU64::new(0));
|
||||
|
||||
Self {
|
||||
reconnect_backoff,
|
||||
url: url.clone(),
|
||||
@@ -222,8 +226,12 @@ where
|
||||
waker: None,
|
||||
pending_messages: Default::default(),
|
||||
_phantom: PhantomData,
|
||||
next_request_id: 0,
|
||||
heartbeat: Default::default(),
|
||||
heartbeat: Heartbeat::new(
|
||||
heartbeat::INTERVAL,
|
||||
heartbeat::TIMEOUT,
|
||||
next_request_id.clone(),
|
||||
),
|
||||
next_request_id,
|
||||
pending_join_requests: Default::default(),
|
||||
login,
|
||||
init_req: init_req.clone(),
|
||||
@@ -447,10 +455,12 @@ where
|
||||
|
||||
// Priority 3: Handle heartbeats.
|
||||
match self.heartbeat.poll(cx) {
|
||||
Poll::Ready(Ok(msg)) => {
|
||||
let (id, msg) = self.make_message("phoenix", msg);
|
||||
self.pending_messages.push_back(msg);
|
||||
self.heartbeat.set_id(id);
|
||||
Poll::Ready(Ok(id)) => {
|
||||
self.pending_messages.push_back(serialize_msg(
|
||||
"phoenix",
|
||||
EgressControlMessage::<()>::Heartbeat(Empty {}),
|
||||
id.copy(),
|
||||
));
|
||||
|
||||
return Poll::Ready(Ok(Event::HeartbeatSent));
|
||||
}
|
||||
@@ -492,19 +502,15 @@ where
|
||||
let request_id = self.fetch_add_request_id();
|
||||
|
||||
// We don't care about the reply type when serializing
|
||||
let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new_message(
|
||||
topic,
|
||||
payload,
|
||||
Some(request_id.copy()),
|
||||
))
|
||||
.expect("we should always be able to serialize a join topic message");
|
||||
let msg = serialize_msg(topic, payload, request_id.copy());
|
||||
|
||||
(request_id, msg)
|
||||
}
|
||||
|
||||
fn fetch_add_request_id(&mut self) -> OutboundRequestId {
|
||||
let next_id = self.next_request_id;
|
||||
self.next_request_id += 1;
|
||||
let next_id = self
|
||||
.next_request_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
|
||||
OutboundRequestId(next_id)
|
||||
}
|
||||
@@ -685,6 +691,19 @@ enum EgressControlMessage<T> {
|
||||
Heartbeat(Empty),
|
||||
}
|
||||
|
||||
fn serialize_msg(
|
||||
topic: impl Into<String>,
|
||||
payload: impl Serialize,
|
||||
request_id: OutboundRequestId,
|
||||
) -> String {
|
||||
serde_json::to_string(&PhoenixMessage::<_, ()>::new_message(
|
||||
topic,
|
||||
payload,
|
||||
Some(request_id),
|
||||
))
|
||||
.expect("we should always be able to serialize a join topic message")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user