diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 3ce2aabf9..3d94ef5b6 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -241,17 +241,7 @@ impl Allocation { self.channel_bindings.handle_failed_binding(channel); } REFRESH => { - if let Some(candidate) = self.ip4_allocation.take() { - self.events.push_back(CandidateEvent::Invalid(candidate)) - } - - if let Some(candidate) = self.ip6_allocation.take() { - self.events.push_back(CandidateEvent::Invalid(candidate)) - } - - self.channel_bindings.clear(); - self.allocation_lifetime = None; - + self.invalidate_allocation(); self.authenticate_and_queue(make_allocate_request()); } _ => {} @@ -402,6 +392,13 @@ impl Allocation { pub fn handle_timeout(&mut self, now: Instant) { self.update_now(now); + if self + .allocation_expires_at() + .is_some_and(|expires_at| now >= expires_at) + { + self.invalidate_allocation(); + } + while let Some(timed_out_request) = self.sent_requests .iter() @@ -420,19 +417,19 @@ impl Allocation { } if let Some(refresh_at) = self.refresh_allocation_at() { - if now >= refresh_at { + if (now >= refresh_at) && !self.refresh_in_flight() { tracing::debug!("Allocation is due for a refresh"); self.authenticate_and_queue(make_refresh_request()); } } - let refresh_messages = self + let channel_refresh_messages = self .channel_bindings .channels_to_refresh(now, |number| self.channel_binding_in_flight(number)) .map(|(number, peer)| make_channel_bind_request(peer, number)) .collect::>(); // Need to allocate here to satisfy borrow-checker. Number of channel refresh messages should be small so this shouldn't be a big impact. - for message in refresh_messages { + for message in channel_refresh_messages { self.authenticate_and_queue(message); } @@ -448,7 +445,11 @@ impl Allocation { } pub fn poll_timeout(&self) -> Option { - let mut earliest_timeout = self.refresh_allocation_at(); + let mut earliest_timeout = if !self.refresh_in_flight() { + self.refresh_allocation_at() + } else { + None + }; for (_, (_, sent_at, backoff)) in self.sent_requests.iter() { earliest_timeout = earliest(earliest_timeout, Some(*sent_at + *backoff)); @@ -522,6 +523,26 @@ impl Allocation { Some(received_at + refresh_after) } + fn allocation_expires_at(&self) -> Option { + let (received_at, lifetime) = self.allocation_lifetime?; + + Some(received_at + lifetime) + } + + fn invalidate_allocation(&mut self) { + if let Some(candidate) = self.ip4_allocation.take() { + self.events.push_back(CandidateEvent::Invalid(candidate)) + } + + if let Some(candidate) = self.ip6_allocation.take() { + self.events.push_back(CandidateEvent::Invalid(candidate)) + } + + self.channel_bindings.clear(); + self.allocation_lifetime = None; + self.sent_requests.clear(); + } + /// Checks whether the given socket is part of this allocation. pub fn has_socket(&self, socket: SocketAddr) -> bool { let is_ip4 = self.ip4_socket().is_some_and(|s| s.address() == socket); @@ -577,6 +598,12 @@ impl Allocation { .any(|(r, _, _)| r.method() == ALLOCATE) } + fn refresh_in_flight(&self) -> bool { + self.sent_requests + .values() + .any(|(r, _, _)| r.method() == REFRESH) + } + /// Check whether this allocation is suspended. /// /// We call it suspended if we have given up making an allocation due to some error. @@ -1601,6 +1628,22 @@ mod tests { assert_eq!(next_msg.method(), REFRESH) } + #[test] + fn allocation_is_refreshed_only_once() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + Instant::now(), + ); + + let refresh_at = allocation.poll_timeout().unwrap(); + + allocation.handle_timeout(refresh_at); + assert!(allocation.poll_timeout().unwrap() > refresh_at); + } + #[test] fn failed_refresh_resets_allocation_lifetime() { let mut allocation = Allocation::for_test(Instant::now()); @@ -1666,6 +1709,59 @@ mod tests { assert!(allocation.is_suspended()) } + #[test] + fn timed_out_refresh_requests_invalid_candidates() { + let start = Instant::now(); + let mut allocation = Allocation::for_test(start); + + // Make an allocation + { + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + start, + ); + let _drained_events = iter::from_fn(|| allocation.poll_event()).collect::>(); + } + + let allocated_at = start; + + // Test that we refresh it. + { + let refresh_at = allocation.poll_timeout().unwrap(); + allocation.handle_timeout(refresh_at); + + let refresh = allocation.next_message().unwrap(); + assert_eq!(refresh.method(), REFRESH); + } + + // Simulate refresh timing out + loop { + let timeout = allocation.poll_timeout().unwrap(); + allocation.handle_timeout(timeout); + + if timeout > (allocated_at + ALLOCATION_LIFETIME) { + break; + } + + let refresh = allocation.next_message().unwrap(); + assert_eq!(refresh.method(), REFRESH); + } + + assert!( + allocation.next_message().is_none(), + "expect to not queue another refresh message if we are past the allocation lifetime" + ); + assert!(allocation.poll_timeout().is_none()); + assert_eq!( + iter::from_fn(|| allocation.poll_event()).collect::>(), + vec![ + CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), + CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), + ] + ) + } + fn ch(peer: SocketAddr, now: Instant) -> Channel { Channel { peer,