From 1ebee00699ecac6f33407dbbdd45502fe694356f Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 15 Jan 2025 15:40:32 +0100 Subject: [PATCH] fix(connlib): prevent time from going backwards (#7758) On a high level, `connlib` is a state machine that gets driven by a custom event-loop. For time-related actions, the state machine computes, when it would like to be woken next. The event-loop sets a timer for that value and emits this value when the timer fires. There is an edge-case where this may result in the time going backwards within the state machine. Specifically, if - for whatever reason - the state machine emits a time value that is in the past, the timer in the `Io` component will fire right away **but the `deadline` will point to the time in the past**. The only thing we are actually interested in is that the timer fires at all. Instead of passing back the deadline of the timer, we fetch the _current_ time and pass that back to the state machine as the current input. This ensures that we never jump back in time because Rust guarantees for calls to `Instant::now` to be monotonic. (https://doc.rust-lang.org/std/time/struct.Instant.html#:~:text=a%20measurement%20of%20a%20monotonically%20nondecreasing%20clock.) --- rust/connlib/tunnel/src/io.rs | 100 ++++++++++++++++++++++++---------- 1 file changed, 72 insertions(+), 28 deletions(-) diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index cc22d70c4..96a7360fb 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -153,14 +153,21 @@ impl Io { if let Some(timeout) = self.timeout.as_mut() { if timeout.poll_unpin(cx).is_ready() { - let deadline = timeout.deadline().into(); + // Always emit `now` as the timeout value. + // This ensures that time within our state machine is always monotonic. + // If we were to use the `deadline` of the timer instead, time may go backwards. + // That is because it is valid to set a `Sleep` to a timestamp in the past. + // It will resolve immediately but it will still report the old timestamp as its deadline. + // To guard against this case, specifically call `Instant::now` here. + let now = Instant::now(); + self.timeout = None; // Clear the timeout. // Piggy back onto the timeout we already have. // It is not important when we call this, just needs to be called occasionally. - self.gso_queue.handle_timeout(deadline); + self.gso_queue.handle_timeout(now); - return Poll::Ready(Ok(Input::Timeout(deadline))); + return Poll::Ready(Ok(Input::Timeout(now))); } } @@ -334,48 +341,85 @@ mod tests { #[tokio::test] async fn timer_is_reset_after_it_fires() { - let now = Instant::now(); + let mut io = Io::for_test(); - let mut io = Io::new( - Arc::new(|_| Err(io::Error::other("not implemented"))), - Arc::new(|_| Err(io::Error::other("not implemented"))), - ); - io.set_tun(Box::new(DummyTun)); + let deadline = Instant::now() + Duration::from_secs(1); + io.reset_timeout(deadline); - io.reset_timeout(now + Duration::from_secs(1)); - - let poll_fn = poll_fn(|cx| { - io.poll( - cx, - // SAFETY: This is a test and we never receive packets here. - unsafe { &mut *addr_of_mut!(DUMMY_BUF) }, - ) - }) - .await - .unwrap(); - - let Input::Timeout(timeout) = poll_fn else { + let Input::Timeout(timeout) = io.next().await else { panic!("Unexpected result"); }; - assert_eq!(timeout, now + Duration::from_secs(1)); + assert!(timeout >= deadline, "timer expire after deadline"); - let poll = io.poll( - &mut Context::from_waker(noop_waker_ref()), - // SAFETY: This is a test and we never receive packets here. - unsafe { &mut *addr_of_mut!(DUMMY_BUF) }, - ); + let poll = io.poll_test(); assert!(poll.is_pending()); assert!(io.timeout.is_none()); } + #[tokio::test] + async fn emits_now_in_case_timeout_is_in_the_past() { + let now = Instant::now(); + let mut io = Io::for_test(); + + io.reset_timeout(now - Duration::from_secs(10)); + + let Input::Timeout(timeout) = io.next().await else { + panic!("Unexpected result"); + }; + + assert!(timeout.duration_since(now) < Duration::from_millis(1)); + } + static mut DUMMY_BUF: Buffers = Buffers { ip: Vec::new(), udp4: Vec::new(), udp6: Vec::new(), }; + /// Helper functions to make the test more concise. + impl Io { + fn for_test() -> Io { + let mut io = Io::new( + Arc::new(|_| Err(io::Error::other("not implemented"))), + Arc::new(|_| Err(io::Error::other("not implemented"))), + ); + io.set_tun(Box::new(DummyTun)); + + io + } + + async fn next( + &mut self, + ) -> Input, impl Iterator>> + { + poll_fn(|cx| { + self.poll( + cx, + // SAFETY: This is a test and we never receive packets here. + unsafe { &mut *addr_of_mut!(DUMMY_BUF) }, + ) + }) + .await + .unwrap() + } + + fn poll_test( + &mut self, + ) -> Poll< + io::Result< + Input, impl Iterator>>, + >, + > { + self.poll( + &mut Context::from_waker(noop_waker_ref()), + // SAFETY: This is a test and we never receive packets here. + unsafe { &mut *addr_of_mut!(DUMMY_BUF) }, + ) + } + } + struct DummyTun; impl Tun for DummyTun {