fix(connlib): prevent time from going backwards (#7758)

On a high level, `connlib` is a state machine that gets driven by a
custom event-loop. For time-related actions, the state machine computes,
when it would like to be woken next. The event-loop sets a timer for
that value and emits this value when the timer fires.

There is an edge-case where this may result in the time going backwards
within the state machine. Specifically, if - for whatever reason - the
state machine emits a time value that is in the past, the timer in the
`Io` component will fire right away **but the `deadline` will point to
the time in the past**.

The only thing we are actually interested in is that the timer fires at
all. Instead of passing back the deadline of the timer, we fetch the
_current_ time and pass that back to the state machine as the current
input. This ensures that we never jump back in time because Rust
guarantees for calls to `Instant::now` to be monotonic.
(https://doc.rust-lang.org/std/time/struct.Instant.html#:~:text=a%20measurement%20of%20a%20monotonically%20nondecreasing%20clock.)
This commit is contained in:
Thomas Eizinger
2025-01-15 15:40:32 +01:00
committed by GitHub
parent 17af9bc28f
commit 1ebee00699

View File

@@ -153,14 +153,21 @@ impl Io {
if let Some(timeout) = self.timeout.as_mut() {
if timeout.poll_unpin(cx).is_ready() {
let deadline = timeout.deadline().into();
// Always emit `now` as the timeout value.
// This ensures that time within our state machine is always monotonic.
// If we were to use the `deadline` of the timer instead, time may go backwards.
// That is because it is valid to set a `Sleep` to a timestamp in the past.
// It will resolve immediately but it will still report the old timestamp as its deadline.
// To guard against this case, specifically call `Instant::now` here.
let now = Instant::now();
self.timeout = None; // Clear the timeout.
// Piggy back onto the timeout we already have.
// It is not important when we call this, just needs to be called occasionally.
self.gso_queue.handle_timeout(deadline);
self.gso_queue.handle_timeout(now);
return Poll::Ready(Ok(Input::Timeout(deadline)));
return Poll::Ready(Ok(Input::Timeout(now)));
}
}
@@ -334,48 +341,85 @@ mod tests {
#[tokio::test]
async fn timer_is_reset_after_it_fires() {
let now = Instant::now();
let mut io = Io::for_test();
let mut io = Io::new(
Arc::new(|_| Err(io::Error::other("not implemented"))),
Arc::new(|_| Err(io::Error::other("not implemented"))),
);
io.set_tun(Box::new(DummyTun));
let deadline = Instant::now() + Duration::from_secs(1);
io.reset_timeout(deadline);
io.reset_timeout(now + Duration::from_secs(1));
let poll_fn = poll_fn(|cx| {
io.poll(
cx,
// SAFETY: This is a test and we never receive packets here.
unsafe { &mut *addr_of_mut!(DUMMY_BUF) },
)
})
.await
.unwrap();
let Input::Timeout(timeout) = poll_fn else {
let Input::Timeout(timeout) = io.next().await else {
panic!("Unexpected result");
};
assert_eq!(timeout, now + Duration::from_secs(1));
assert!(timeout >= deadline, "timer expire after deadline");
let poll = io.poll(
&mut Context::from_waker(noop_waker_ref()),
// SAFETY: This is a test and we never receive packets here.
unsafe { &mut *addr_of_mut!(DUMMY_BUF) },
);
let poll = io.poll_test();
assert!(poll.is_pending());
assert!(io.timeout.is_none());
}
#[tokio::test]
async fn emits_now_in_case_timeout_is_in_the_past() {
let now = Instant::now();
let mut io = Io::for_test();
io.reset_timeout(now - Duration::from_secs(10));
let Input::Timeout(timeout) = io.next().await else {
panic!("Unexpected result");
};
assert!(timeout.duration_since(now) < Duration::from_millis(1));
}
static mut DUMMY_BUF: Buffers = Buffers {
ip: Vec::new(),
udp4: Vec::new(),
udp6: Vec::new(),
};
/// Helper functions to make the test more concise.
impl Io {
fn for_test() -> Io {
let mut io = Io::new(
Arc::new(|_| Err(io::Error::other("not implemented"))),
Arc::new(|_| Err(io::Error::other("not implemented"))),
);
io.set_tun(Box::new(DummyTun));
io
}
async fn next(
&mut self,
) -> Input<impl Iterator<Item = IpPacket>, impl Iterator<Item = DatagramIn<'static>>>
{
poll_fn(|cx| {
self.poll(
cx,
// SAFETY: This is a test and we never receive packets here.
unsafe { &mut *addr_of_mut!(DUMMY_BUF) },
)
})
.await
.unwrap()
}
fn poll_test(
&mut self,
) -> Poll<
io::Result<
Input<impl Iterator<Item = IpPacket>, impl Iterator<Item = DatagramIn<'static>>>,
>,
> {
self.poll(
&mut Context::from_waker(noop_waker_ref()),
// SAFETY: This is a test and we never receive packets here.
unsafe { &mut *addr_of_mut!(DUMMY_BUF) },
)
}
}
struct DummyTun;
impl Tun for DummyTun {