diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index cc22d70c4..96a7360fb 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -153,14 +153,21 @@ impl Io { if let Some(timeout) = self.timeout.as_mut() { if timeout.poll_unpin(cx).is_ready() { - let deadline = timeout.deadline().into(); + // Always emit `now` as the timeout value. + // This ensures that time within our state machine is always monotonic. + // If we were to use the `deadline` of the timer instead, time may go backwards. + // That is because it is valid to set a `Sleep` to a timestamp in the past. + // It will resolve immediately but it will still report the old timestamp as its deadline. + // To guard against this case, specifically call `Instant::now` here. + let now = Instant::now(); + self.timeout = None; // Clear the timeout. // Piggy back onto the timeout we already have. // It is not important when we call this, just needs to be called occasionally. - self.gso_queue.handle_timeout(deadline); + self.gso_queue.handle_timeout(now); - return Poll::Ready(Ok(Input::Timeout(deadline))); + return Poll::Ready(Ok(Input::Timeout(now))); } } @@ -334,48 +341,85 @@ mod tests { #[tokio::test] async fn timer_is_reset_after_it_fires() { - let now = Instant::now(); + let mut io = Io::for_test(); - let mut io = Io::new( - Arc::new(|_| Err(io::Error::other("not implemented"))), - Arc::new(|_| Err(io::Error::other("not implemented"))), - ); - io.set_tun(Box::new(DummyTun)); + let deadline = Instant::now() + Duration::from_secs(1); + io.reset_timeout(deadline); - io.reset_timeout(now + Duration::from_secs(1)); - - let poll_fn = poll_fn(|cx| { - io.poll( - cx, - // SAFETY: This is a test and we never receive packets here. - unsafe { &mut *addr_of_mut!(DUMMY_BUF) }, - ) - }) - .await - .unwrap(); - - let Input::Timeout(timeout) = poll_fn else { + let Input::Timeout(timeout) = io.next().await else { panic!("Unexpected result"); }; - assert_eq!(timeout, now + Duration::from_secs(1)); + assert!(timeout >= deadline, "timer expire after deadline"); - let poll = io.poll( - &mut Context::from_waker(noop_waker_ref()), - // SAFETY: This is a test and we never receive packets here. - unsafe { &mut *addr_of_mut!(DUMMY_BUF) }, - ); + let poll = io.poll_test(); assert!(poll.is_pending()); assert!(io.timeout.is_none()); } + #[tokio::test] + async fn emits_now_in_case_timeout_is_in_the_past() { + let now = Instant::now(); + let mut io = Io::for_test(); + + io.reset_timeout(now - Duration::from_secs(10)); + + let Input::Timeout(timeout) = io.next().await else { + panic!("Unexpected result"); + }; + + assert!(timeout.duration_since(now) < Duration::from_millis(1)); + } + static mut DUMMY_BUF: Buffers = Buffers { ip: Vec::new(), udp4: Vec::new(), udp6: Vec::new(), }; + /// Helper functions to make the test more concise. + impl Io { + fn for_test() -> Io { + let mut io = Io::new( + Arc::new(|_| Err(io::Error::other("not implemented"))), + Arc::new(|_| Err(io::Error::other("not implemented"))), + ); + io.set_tun(Box::new(DummyTun)); + + io + } + + async fn next( + &mut self, + ) -> Input, impl Iterator>> + { + poll_fn(|cx| { + self.poll( + cx, + // SAFETY: This is a test and we never receive packets here. + unsafe { &mut *addr_of_mut!(DUMMY_BUF) }, + ) + }) + .await + .unwrap() + } + + fn poll_test( + &mut self, + ) -> Poll< + io::Result< + Input, impl Iterator>>, + >, + > { + self.poll( + &mut Context::from_waker(noop_waker_ref()), + // SAFETY: This is a test and we never receive packets here. + unsafe { &mut *addr_of_mut!(DUMMY_BUF) }, + ) + } + } + struct DummyTun; impl Tun for DummyTun {