From 4ffc49eef9765c4f19bcfa6750022cfa6ebc40a0 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 26 Jun 2024 06:19:56 +1000 Subject: [PATCH] fix(snownet): ensure failed refresh requests invalidate allocation (#5538) Whilst we had a unit-test for this behaviour, it was written poorly and didn't assert on the correct thing. Instead, I happened to pass because we advanced time far enough to trigger the actual expiry of the allocation instead of directly expiring it upon the last failed retry of the refresh request. Re-writing this test then surfaced that we were in fact no invalidating the allocation correctly. In real-time, this represents a difference of 5 minutes within which a client may try to use a relay candidate that is in fact no longer working. Related: #5519. --- rust/connlib/snownet/src/allocation.rs | 59 +++++++++++--------------- 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index a76acfe4e..6d6df5c0e 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -575,9 +575,16 @@ impl Allocation { tracing::debug!(id = ?request.transaction_id(), method = %request.method(), %dst, "Request timed out after {backoff_duration:?}, re-sending"); let needs_auth = request.method() != BINDING; + let is_refresh = request.method() == REFRESH; if needs_auth { - self.authenticate_and_queue(request, Some(backoff)); + let queued = self.authenticate_and_queue(request, Some(backoff)); + + // If we fail to queue the refresh message because we've exceeded our backoff, give up. + if !queued && is_refresh { + self.invalidate_allocation(); + } + continue; } @@ -587,12 +594,7 @@ impl Allocation { if let Some(refresh_at) = self.refresh_allocation_at() { if (now >= refresh_at) && !self.refresh_in_flight() { tracing::debug!("Allocation is due for a refresh"); - let queued = self.authenticate_and_queue(make_refresh_request(), None); - - // If we fail to queue the refresh message because we've exceeded our backoff, give up. - if !queued { - self.invalidate_allocation(); - } + self.authenticate_and_queue(make_refresh_request(), None); } } @@ -908,9 +910,7 @@ impl Allocation { }; let authenticated_message = authenticate(message, credentials); - self.queue(dst, authenticated_message, backoff); - - true + self.queue(dst, authenticated_message, backoff) } fn queue( @@ -2080,6 +2080,11 @@ mod tests { #[test] fn timed_out_refresh_requests_invalid_candidates() { + let _guard = tracing_subscriber::fmt() + .with_env_filter("trace") + .with_test_writer() + .set_default(); + let start = Instant::now(); let mut allocation = Allocation::for_test_ip4(start).with_binding_response(PEER1); @@ -2103,18 +2108,10 @@ mod tests { } // Simulate refresh timing out - loop { - let timeout = allocation.poll_timeout().unwrap(); - allocation.handle_timeout(timeout); - - if let Some(refresh) = allocation.next_message() { - assert_eq!(refresh.method(), REFRESH); - } else { - break; - } + for _ in backoff::steps(start) { + allocation.handle_timeout(allocation.poll_timeout().unwrap()); } - assert!(allocation.poll_timeout().is_none()); assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ @@ -2125,24 +2122,16 @@ mod tests { } #[test] - fn expires_allocation_invalidates_candidaets() { + fn expires_allocation_invalidates_candidates() { let start = Instant::now(); - let mut allocation = Allocation::for_test_ip4(start).with_binding_response(PEER1); + let mut allocation = Allocation::for_test_ip4(start) + .with_binding_response(PEER1) + .with_allocate_response(&[RELAY_ADDR_IP4, RELAY_ADDR_IP6]); - // Make an allocation - { - let allocate = allocation.next_message().unwrap(); - allocation.handle_test_input_ip4( - &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), - start, - ); - let _drained_events = iter::from_fn(|| allocation.poll_event()).collect::>(); - } + let _drained_events = iter::from_fn(|| allocation.poll_event()).collect::>(); allocation.handle_timeout(start + ALLOCATION_LIFETIME); - assert!(allocation.poll_timeout().is_none()); - assert!(allocation.next_message().is_none()); assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ @@ -2517,14 +2506,14 @@ mod tests { fn with_binding_response(mut self, srflx_addr: SocketAddr) -> Self { let binding = self.next_message().unwrap(); - self.handle_test_input_ip4(&binding_response(&binding, srflx_addr), Instant::now()); + self.handle_test_input_ip4(&binding_response(&binding, srflx_addr), self.last_now); self } fn with_allocate_response(mut self, relay_addrs: &[SocketAddr]) -> Self { let allocate = self.next_message().unwrap(); - self.handle_test_input_ip4(&allocate_response(&allocate, relay_addrs), Instant::now()); + self.handle_test_input_ip4(&allocate_response(&allocate, relay_addrs), self.last_now); self }