diff --git a/rust/phoenix-channel/src/heartbeat.rs b/rust/phoenix-channel/src/heartbeat.rs index 6f61bba7f..dc49b7b2b 100644 --- a/rust/phoenix-channel/src/heartbeat.rs +++ b/rust/phoenix-channel/src/heartbeat.rs @@ -1,81 +1,93 @@ -use crate::{EgressControlMessage, OutboundRequestId}; +use crate::OutboundRequestId; +use futures::FutureExt; use std::{ pin::Pin, + sync::{atomic::AtomicU64, Arc}, task::{ready, Context, Poll}, time::Duration, }; use tokio::time::MissedTickBehavior; -const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); +pub const INTERVAL: Duration = Duration::from_secs(30); +pub const TIMEOUT: Duration = Duration::from_secs(5); pub struct Heartbeat { /// When to send the next heartbeat. interval: Pin>, + + timeout: Duration, + /// The ID of our heatbeat if we haven't received a reply yet. - id: Option, + pending: Option<(OutboundRequestId, Pin>)>, + + next_request_id: Arc, } impl Heartbeat { - pub fn maybe_handle_reply(&mut self, id: OutboundRequestId) -> bool { - let Some(pending) = self.id.take() else { - return false; - }; - - if pending != id { - return false; - } - - self.id = None; - true - } - - pub fn set_id(&mut self, id: OutboundRequestId) { - self.id = Some(id); - } - - pub fn poll( - &mut self, - cx: &mut Context, - ) -> Poll, MissedLastHeartbeat>> { - ready!(self.interval.poll_tick(cx)); - - if self.id.is_some() { - self.id = None; - return Poll::Ready(Err(MissedLastHeartbeat {})); - } - - Poll::Ready(Ok(EgressControlMessage::Heartbeat(crate::Empty {}))) - } - - fn new(interval: Duration) -> Self { + pub fn new(interval: Duration, timeout: Duration, next_request_id: Arc) -> Self { let mut interval = tokio::time::interval(interval); interval.set_missed_tick_behavior(MissedTickBehavior::Skip); Self { interval: Box::pin(interval), - id: Default::default(), + pending: Default::default(), + next_request_id, + timeout, } } + + pub fn maybe_handle_reply(&mut self, id: OutboundRequestId) -> bool { + match self.pending.as_ref() { + Some((pending, timeout)) if pending == &id && !dbg!(timeout.is_elapsed()) => { + self.pending = None; + + true + } + _ => false, + } + } + + pub fn poll( + &mut self, + cx: &mut Context, + ) -> Poll> { + if let Some((_, timeout)) = self.pending.as_mut() { + ready!(timeout.poll_unpin(cx)); + self.pending = None; + return Poll::Ready(Err(MissedLastHeartbeat {})); + } + + ready!(self.interval.poll_tick(cx)); + + let next_id = self + .next_request_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.pending = Some(( + OutboundRequestId(next_id), + Box::pin(tokio::time::sleep(self.timeout)), + )); + + Poll::Ready(Ok(OutboundRequestId(next_id))) + } } #[derive(Debug)] pub struct MissedLastHeartbeat {} -impl Default for Heartbeat { - fn default() -> Self { - Self::new(HEARTBEAT_INTERVAL) - } -} - #[cfg(test)] mod tests { use super::*; + use futures::future::Either; use std::{future::poll_fn, time::Instant}; + const INTERVAL: Duration = Duration::from_millis(30); + const TIMEOUT: Duration = Duration::from_millis(5); + #[tokio::test] async fn returns_heartbeat_after_interval() { - let mut heartbeat = Heartbeat::new(Duration::from_millis(30)); - let _ = poll_fn(|cx| heartbeat.poll(cx)).await; // Tick once at startup. + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); + let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap(); // Tick once at startup. + heartbeat.maybe_handle_reply(id); let start = Instant::now(); @@ -84,15 +96,25 @@ mod tests { let elapsed = start.elapsed(); assert!(result.is_ok()); - assert!(elapsed >= Duration::from_millis(10)); + assert!(elapsed >= INTERVAL); } #[tokio::test] async fn fails_if_response_is_not_provided_before_next_poll() { - let mut heartbeat = Heartbeat::new(Duration::from_millis(10)); + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); let _ = poll_fn(|cx| heartbeat.poll(cx)).await; - heartbeat.set_id(OutboundRequestId::for_test(1)); + + let result = poll_fn(|cx| heartbeat.poll(cx)).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn ignores_other_ids() { + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); + + let _ = poll_fn(|cx| heartbeat.poll(cx)).await; + heartbeat.maybe_handle_reply(OutboundRequestId::for_test(2)); let result = poll_fn(|cx| heartbeat.poll(cx)).await; assert!(result.is_err()); @@ -100,13 +122,34 @@ mod tests { #[tokio::test] async fn succeeds_if_response_is_provided_inbetween_polls() { - let mut heartbeat = Heartbeat::new(Duration::from_millis(10)); + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); - let _ = poll_fn(|cx| heartbeat.poll(cx)).await; - heartbeat.set_id(OutboundRequestId::for_test(1)); - heartbeat.maybe_handle_reply(OutboundRequestId::for_test(1)); + let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap(); + heartbeat.maybe_handle_reply(id); let result = poll_fn(|cx| heartbeat.poll(cx)).await; assert!(result.is_ok()); } + + #[tokio::test] + async fn fails_if_not_provided_within_timeout() { + let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0))); + + let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap(); + + let select = futures::future::select( + tokio::time::sleep(TIMEOUT * 2).boxed(), + poll_fn(|cx| heartbeat.poll(cx)), + ) + .await; + + match select { + Either::Left(((), _)) => panic!("timeout should not resolve"), + Either::Right((Ok(_), _)) => panic!("heartbeat should fail and not issue new ID"), + Either::Right((Err(_), _)) => {} + } + + let handled = heartbeat.maybe_handle_reply(id); + assert!(!handled); + } } diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index e622a5fc1..e90c1278b 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -23,6 +23,8 @@ use tokio_tungstenite::{ }; pub use login_url::{LoginUrl, LoginUrlError}; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; // TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway // See https://github.com/firezone/firezone/issues/2158 @@ -30,7 +32,7 @@ pub struct PhoenixChannel { state: State, waker: Option, pending_messages: VecDeque, - next_request_id: u64, + next_request_id: Arc, heartbeat: Heartbeat, @@ -208,6 +210,8 @@ where init_req: TInitReq, reconnect_backoff: ExponentialBackoff, ) -> Self { + let next_request_id = Arc::new(AtomicU64::new(0)); + Self { reconnect_backoff, url: url.clone(), @@ -222,8 +226,12 @@ where waker: None, pending_messages: Default::default(), _phantom: PhantomData, - next_request_id: 0, - heartbeat: Default::default(), + heartbeat: Heartbeat::new( + heartbeat::INTERVAL, + heartbeat::TIMEOUT, + next_request_id.clone(), + ), + next_request_id, pending_join_requests: Default::default(), login, init_req: init_req.clone(), @@ -447,10 +455,12 @@ where // Priority 3: Handle heartbeats. match self.heartbeat.poll(cx) { - Poll::Ready(Ok(msg)) => { - let (id, msg) = self.make_message("phoenix", msg); - self.pending_messages.push_back(msg); - self.heartbeat.set_id(id); + Poll::Ready(Ok(id)) => { + self.pending_messages.push_back(serialize_msg( + "phoenix", + EgressControlMessage::<()>::Heartbeat(Empty {}), + id.copy(), + )); return Poll::Ready(Ok(Event::HeartbeatSent)); } @@ -492,19 +502,15 @@ where let request_id = self.fetch_add_request_id(); // We don't care about the reply type when serializing - let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new_message( - topic, - payload, - Some(request_id.copy()), - )) - .expect("we should always be able to serialize a join topic message"); + let msg = serialize_msg(topic, payload, request_id.copy()); (request_id, msg) } fn fetch_add_request_id(&mut self) -> OutboundRequestId { - let next_id = self.next_request_id; - self.next_request_id += 1; + let next_id = self + .next_request_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); OutboundRequestId(next_id) } @@ -685,6 +691,19 @@ enum EgressControlMessage { Heartbeat(Empty), } +fn serialize_msg( + topic: impl Into, + payload: impl Serialize, + request_id: OutboundRequestId, +) -> String { + serde_json::to_string(&PhoenixMessage::<_, ()>::new_message( + topic, + payload, + Some(request_id), + )) + .expect("we should always be able to serialize a join topic message") +} + #[cfg(test)] mod tests { use super::*;