From ecce0244dcdbb5f7a8a6aa6f66d6b4ba4ec27331 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 26 Mar 2024 10:11:02 +1100 Subject: [PATCH] feat(phoenix-channel): fail on missing heartbeat after 5s (#4296) This PR fixes a bug and adds a missing feature to `phoenix-channel`. 1. Previously, we used to erroneously reset the heartbeat state on all sorts of empty replies, not just the specific one from the heartbeat. 2. We only failed on missing heartbeats when it was time to send the next one. With this PR, we correct the first bug and add a dedicated timeout of 5s for the heartbeat reply. --- rust/phoenix-channel/src/heartbeat.rs | 145 +++++++++++++++++--------- rust/phoenix-channel/src/lib.rs | 49 ++++++--- 2 files changed, 128 insertions(+), 66 deletions(-) diff --git a/rust/phoenix-channel/src/heartbeat.rs b/rust/phoenix-channel/src/heartbeat.rs index 6f61bba7f..dc49b7b2b 100644 --- a/rust/phoenix-channel/src/heartbeat.rs +++ b/rust/phoenix-channel/src/heartbeat.rs @@ -1,81 +1,93 @@ -use crate::{EgressControlMessage, OutboundRequestId}; +use crate::OutboundRequestId; +use futures::FutureExt; use std::{ pin::Pin, + sync::{atomic::AtomicU64, Arc}, task::{ready, Context, Poll}, time::Duration, }; use tokio::time::MissedTickBehavior; -const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); +pub const INTERVAL: Duration = Duration::from_secs(30); +pub const TIMEOUT: Duration = Duration::from_secs(5); pub struct Heartbeat { /// When to send the next heartbeat. interval: Pin>, + + timeout: Duration, + /// The ID of our heatbeat if we haven't received a reply yet. - id: Option, + pending: Option<(OutboundRequestId, Pin>)>, + + next_request_id: Arc, } impl Heartbeat { - pub fn maybe_handle_reply(&mut self, id: OutboundRequestId) -> bool { - let Some(pending) = self.id.take() else { - return false; - }; - - if pending != id { - return false; - } - - self.id = None; - true - } - - pub fn set_id(&mut self, id: OutboundRequestId) { - self.id = Some(id); - } - - pub fn poll( - &mut self, - cx: &mut Context, - ) -> Poll, MissedLastHeartbeat>> { - ready!(self.interval.poll_tick(cx)); - - if self.id.is_some() { - self.id = None; - return Poll::Ready(Err(MissedLastHeartbeat {})); - } - - Poll::Ready(Ok(EgressControlMessage::Heartbeat(crate::Empty {}))) - } - - fn new(interval: Duration) -> Self { + pub fn new(interval: Duration, timeout: Duration, next_request_id: Arc) -> Self { let mut interval = tokio::time::interval(interval); interval.set_missed_tick_behavior(MissedTickBehavior::Skip); Self { interval: Box::pin(interval), - id: Default::default(), + pending: Default::default(), + next_request_id, + timeout, } } + + pub fn maybe_handle_reply(&mut self, id: OutboundRequestId) -> bool { + match self.pending.as_ref() { + Some((pending, timeout)) if pending == &id && !dbg!(timeout.is_elapsed()) => { + self.pending = None; + + true + } + _ => false, + } + } + + pub fn poll( + &mut self, + cx: &mut Context, + ) -> Poll> { + if let Some((_, timeout)) = self.pending.as_mut() { + ready!(timeout.poll_unpin(cx)); + self.pending = None; + return Poll::Ready(Err(MissedLastHeartbeat {})); + } + + ready!(self.interval.poll_tick(cx)); + + let next_id = self + .next_request_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.pending = Some(( + OutboundRequestId(next_id), + Box::pin(tokio::time::sleep(self.timeout)), + )); + + Poll::Ready(Ok(OutboundRequestId(next_id))) + } } #[derive(Debug)] pub struct MissedLastHeartbeat {} -impl Default for Heartbeat { - fn default() -> Self { - Self::new(HEARTBEAT_INTERVAL) - } -} - #[cfg(test)] mod tests { use super::*; + use futures::future::Either; use std::{future::poll_fn, time::Instant}; + const INTERVAL: Duration = Duration::from_millis(30); + const TIMEOUT: Duration = Duration::from_millis(5); + #[tokio::test] async fn returns_heartbeat_after_interval() { - let mut heartbeat = Heartbeat::new(Duration::from_millis(30)); - let _ = poll_fn(|cx| heartbeat.poll(cx)).await; // Tick once at startup. + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); + let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap(); // Tick once at startup. + heartbeat.maybe_handle_reply(id); let start = Instant::now(); @@ -84,15 +96,25 @@ mod tests { let elapsed = start.elapsed(); assert!(result.is_ok()); - assert!(elapsed >= Duration::from_millis(10)); + assert!(elapsed >= INTERVAL); } #[tokio::test] async fn fails_if_response_is_not_provided_before_next_poll() { - let mut heartbeat = Heartbeat::new(Duration::from_millis(10)); + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); let _ = poll_fn(|cx| heartbeat.poll(cx)).await; - heartbeat.set_id(OutboundRequestId::for_test(1)); + + let result = poll_fn(|cx| heartbeat.poll(cx)).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn ignores_other_ids() { + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); + + let _ = poll_fn(|cx| heartbeat.poll(cx)).await; + heartbeat.maybe_handle_reply(OutboundRequestId::for_test(2)); let result = poll_fn(|cx| heartbeat.poll(cx)).await; assert!(result.is_err()); @@ -100,13 +122,34 @@ mod tests { #[tokio::test] async fn succeeds_if_response_is_provided_inbetween_polls() { - let mut heartbeat = Heartbeat::new(Duration::from_millis(10)); + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); - let _ = poll_fn(|cx| heartbeat.poll(cx)).await; - heartbeat.set_id(OutboundRequestId::for_test(1)); - heartbeat.maybe_handle_reply(OutboundRequestId::for_test(1)); + let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap(); + heartbeat.maybe_handle_reply(id); let result = poll_fn(|cx| heartbeat.poll(cx)).await; assert!(result.is_ok()); } + + #[tokio::test] + async fn fails_if_not_provided_within_timeout() { + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); + + let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap(); + + let select = futures::future::select( + tokio::time::sleep(TIMEOUT * 2).boxed(), + poll_fn(|cx| heartbeat.poll(cx)), + ) + .await; + + match select { + Either::Left(((), _)) => panic!("timeout should not resolve"), + Either::Right((Ok(_), _)) => panic!("heartbeat should fail and not issue new ID"), + Either::Right((Err(_), _)) => {} + } + + let handled = heartbeat.maybe_handle_reply(id); + assert!(!handled); + } } diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index e622a5fc1..e90c1278b 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -23,6 +23,8 @@ use tokio_tungstenite::{ }; pub use login_url::{LoginUrl, LoginUrlError}; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; // TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway // See https://github.com/firezone/firezone/issues/2158 @@ -30,7 +32,7 @@ pub struct PhoenixChannel { state: State, waker: Option, pending_messages: VecDeque, - next_request_id: u64, + next_request_id: Arc, heartbeat: Heartbeat, @@ -208,6 +210,8 @@ where init_req: TInitReq, reconnect_backoff: ExponentialBackoff, ) -> Self { + let next_request_id = Arc::new(AtomicU64::new(0)); + Self { reconnect_backoff, url: url.clone(), @@ -222,8 +226,12 @@ where waker: None, pending_messages: Default::default(), _phantom: PhantomData, - next_request_id: 0, - heartbeat: Default::default(), + heartbeat: Heartbeat::new( + heartbeat::INTERVAL, + heartbeat::TIMEOUT, + next_request_id.clone(), + ), + next_request_id, pending_join_requests: Default::default(), login, init_req: init_req.clone(), @@ -447,10 +455,12 @@ where // Priority 3: Handle heartbeats. match self.heartbeat.poll(cx) { - Poll::Ready(Ok(msg)) => { - let (id, msg) = self.make_message("phoenix", msg); - self.pending_messages.push_back(msg); - self.heartbeat.set_id(id); + Poll::Ready(Ok(id)) => { + self.pending_messages.push_back(serialize_msg( + "phoenix", + EgressControlMessage::<()>::Heartbeat(Empty {}), + id.copy(), + )); return Poll::Ready(Ok(Event::HeartbeatSent)); } @@ -492,19 +502,15 @@ where let request_id = self.fetch_add_request_id(); // We don't care about the reply type when serializing - let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new_message( - topic, - payload, - Some(request_id.copy()), - )) - .expect("we should always be able to serialize a join topic message"); + let msg = serialize_msg(topic, payload, request_id.copy()); (request_id, msg) } fn fetch_add_request_id(&mut self) -> OutboundRequestId { - let next_id = self.next_request_id; - self.next_request_id += 1; + let next_id = self + .next_request_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); OutboundRequestId(next_id) } @@ -685,6 +691,19 @@ enum EgressControlMessage { Heartbeat(Empty), } +fn serialize_msg( + topic: impl Into, + payload: impl Serialize, + request_id: OutboundRequestId, +) -> String { + serde_json::to_string(&PhoenixMessage::<_, ()>::new_message( + topic, + payload, + Some(request_id), + )) + .expect("we should always be able to serialize a join topic message") +} + #[cfg(test)] mod tests { use super::*;