diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index a76acfe4e..6d6df5c0e 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -575,9 +575,16 @@ impl Allocation { tracing::debug!(id = ?request.transaction_id(), method = %request.method(), %dst, "Request timed out after {backoff_duration:?}, re-sending"); let needs_auth = request.method() != BINDING; + let is_refresh = request.method() == REFRESH; if needs_auth { - self.authenticate_and_queue(request, Some(backoff)); + let queued = self.authenticate_and_queue(request, Some(backoff)); + + // If we fail to queue the refresh message because we've exceeded our backoff, give up. + if !queued && is_refresh { + self.invalidate_allocation(); + } + continue; } @@ -587,12 +594,7 @@ impl Allocation { if let Some(refresh_at) = self.refresh_allocation_at() { if (now >= refresh_at) && !self.refresh_in_flight() { tracing::debug!("Allocation is due for a refresh"); - let queued = self.authenticate_and_queue(make_refresh_request(), None); - - // If we fail to queue the refresh message because we've exceeded our backoff, give up. - if !queued { - self.invalidate_allocation(); - } + self.authenticate_and_queue(make_refresh_request(), None); } } @@ -908,9 +910,7 @@ impl Allocation { }; let authenticated_message = authenticate(message, credentials); - self.queue(dst, authenticated_message, backoff); - - true + self.queue(dst, authenticated_message, backoff) } fn queue( @@ -2080,6 +2080,11 @@ mod tests { #[test] fn timed_out_refresh_requests_invalid_candidates() { + let _guard = tracing_subscriber::fmt() + .with_env_filter("trace") + .with_test_writer() + .set_default(); + let start = Instant::now(); let mut allocation = Allocation::for_test_ip4(start).with_binding_response(PEER1); @@ -2103,18 +2108,10 @@ mod tests { } // Simulate refresh timing out - loop { - let timeout = allocation.poll_timeout().unwrap(); - allocation.handle_timeout(timeout); - - if let Some(refresh) = allocation.next_message() { - assert_eq!(refresh.method(), REFRESH); - } else { - break; - } + for _ in backoff::steps(start) { + allocation.handle_timeout(allocation.poll_timeout().unwrap()); } - assert!(allocation.poll_timeout().is_none()); assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ @@ -2125,24 +2122,16 @@ mod tests { } #[test] - fn expires_allocation_invalidates_candidaets() { + fn expires_allocation_invalidates_candidates() { let start = Instant::now(); - let mut allocation = Allocation::for_test_ip4(start).with_binding_response(PEER1); + let mut allocation = Allocation::for_test_ip4(start) + .with_binding_response(PEER1) + .with_allocate_response(&[RELAY_ADDR_IP4, RELAY_ADDR_IP6]); - // Make an allocation - { - let allocate = allocation.next_message().unwrap(); - allocation.handle_test_input_ip4( - &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), - start, - ); - let _drained_events = iter::from_fn(|| allocation.poll_event()).collect::>(); - } + let _drained_events = iter::from_fn(|| allocation.poll_event()).collect::>(); allocation.handle_timeout(start + ALLOCATION_LIFETIME); - assert!(allocation.poll_timeout().is_none()); - assert!(allocation.next_message().is_none()); assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ @@ -2517,14 +2506,14 @@ mod tests { fn with_binding_response(mut self, srflx_addr: SocketAddr) -> Self { let binding = self.next_message().unwrap(); - self.handle_test_input_ip4(&binding_response(&binding, srflx_addr), Instant::now()); + self.handle_test_input_ip4(&binding_response(&binding, srflx_addr), self.last_now); self } fn with_allocate_response(mut self, relay_addrs: &[SocketAddr]) -> Self { let allocate = self.next_message().unwrap(); - self.handle_test_input_ip4(&allocate_response(&allocate, relay_addrs), Instant::now()); + self.handle_test_input_ip4(&allocate_response(&allocate, relay_addrs), self.last_now); self }