diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 43dd8b061..099983346 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -9,7 +9,7 @@ use connlib_shared::{ CallbackErrorFacade, Callbacks, Error, Result, }; use device_channel::Device; -use futures_util::{future::BoxFuture, task::AtomicWaker, FutureExt}; +use futures_util::{task::AtomicWaker, FutureExt}; use peer::PacketTransform; use peer_store::PeerStore; use snownet::{Node, Server}; @@ -19,6 +19,7 @@ use std::{ fmt, hash::Hash, io, + pin::Pin, task::{ready, Context, Poll}, time::{Duration, Instant}, }; @@ -122,6 +123,11 @@ where Poll::Pending => {} } + // After any state change, check what the new timeout is and reset it if necessary. + if self.connections_state.poll_timeout(cx).is_ready() { + cx.waker().wake_by_ref() + } + Poll::Pending } } @@ -180,6 +186,11 @@ where } } + // After any state change, check what the new timeout is and reset it if necessary. + if self.connections_state.poll_timeout(cx).is_ready() { + cx.waker().wake_by_ref() + } + Poll::Pending } } @@ -229,7 +240,7 @@ where struct ConnectionState { pub node: Node, write_buf: Box<[u8; MAX_UDP_SIZE]>, - connection_pool_timeout: BoxFuture<'static, std::time::Instant>, + timeout: Option>>, stats_timer: tokio::time::Interval, sockets: Sockets, } @@ -242,9 +253,9 @@ where Ok(ConnectionState { node: Node::new(private_key, std::time::Instant::now()), write_buf: Box::new([0; MAX_UDP_SIZE]), - connection_pool_timeout: sleep_until(std::time::Instant::now()).boxed(), sockets: Sockets::new()?, stats_timer: tokio::time::interval(Duration::from_secs(60)), + timeout: None, }) } @@ -335,17 +346,6 @@ where } fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Poll::Ready(prev_timeout) = self.connection_pool_timeout.poll_unpin(cx) { - self.node.handle_timeout(prev_timeout); - if let Some(new_timeout) = self.node.poll_timeout() { - debug_assert_ne!(prev_timeout, new_timeout, "Timer busy loop!"); - - self.connection_pool_timeout = sleep_until(new_timeout).boxed(); - } - - cx.waker().wake_by_ref(); - } - if self.stats_timer.poll_tick(cx).is_ready() { let (node_stats, conn_stats) = self.node.stats(); @@ -386,6 +386,33 @@ where Poll::Pending } + + fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if let Some(timeout) = self.node.poll_timeout() { + let timeout = tokio::time::Instant::from_std(timeout); + + match self.timeout.as_mut() { + Some(existing_timeout) if existing_timeout.deadline() != timeout => { + existing_timeout.as_mut().reset(timeout) + } + Some(_) => {} + None => self.timeout = Some(Box::pin(tokio::time::sleep_until(timeout))), + } + } + + if let Some(timeout) = self.timeout.as_mut() { + ready!(timeout.poll_unpin(cx)); + self.node.handle_timeout(timeout.deadline().into()); + + return Poll::Ready(()); + } + + // Technically, we should set a waker here because we don't have a timer. + // But the only place where we set a timer is a few lines up. + // That is the same path that will re-poll it so there is no point in using a waker. + // We might want to consider making a `MaybeSleep` type that encapsulates a waker so we don't need to think about it as hard. + Poll::Pending + } } pub enum Event { @@ -403,9 +430,3 @@ pub enum Event { SendPacket(IpPacket<'static>), StopPeer(TId), } - -async fn sleep_until(deadline: Instant) -> Instant { - tokio::time::sleep_until(deadline.into()).await; - - deadline -}