diff --git a/rust/Cargo.lock b/rust/Cargo.lock index b4d07c484..3ba7266ce 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -5944,7 +5944,6 @@ dependencies = [ name = "snownet" version = "0.1.0" dependencies = [ - "backoff", "boringtun", "bytecodec", "bytes", diff --git a/rust/connlib/snownet/Cargo.toml b/rust/connlib/snownet/Cargo.toml index d2a1f88b8..d4b95c8c4 100644 --- a/rust/connlib/snownet/Cargo.toml +++ b/rust/connlib/snownet/Cargo.toml @@ -5,7 +5,6 @@ edition = { workspace = true } license = { workspace = true } [dependencies] -backoff = { workspace = true } boringtun = { workspace = true } bytecodec = { workspace = true } bytes = { workspace = true } diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 17db4c50b..42a0123d5 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -1,9 +1,7 @@ use crate::{ backoff::{self, ExponentialBackoff}, node::{SessionId, Transmit}, - utils::earliest, }; -use ::backoff::backoff::Backoff; use bytecodec::{DecodeExt as _, EncodeExt as _}; use firezone_logging::{err_with_src, std_dyn_err}; use hex_display::HexDisplayExt as _; @@ -11,6 +9,7 @@ use rand::random; use ringbuffer::{AllocRingBuffer, RingBuffer as _}; use std::{ collections::{BTreeMap, VecDeque}, + iter, net::{SocketAddr, SocketAddrV4, SocketAddrV6}, time::{Duration, Instant}, }; @@ -81,22 +80,11 @@ pub struct Allocation { buffered_transmits: VecDeque>, events: VecDeque, - sent_requests: BTreeMap< - TransactionId, - ( - SocketAddr, - Message, - Instant, - Duration, - ExponentialBackoff, - ), - >, + sent_requests: BTreeMap, ExponentialBackoff)>, channel_bindings: ChannelBindings, buffered_channel_bindings: AllocRingBuffer, - last_now: Instant, - credentials: Option, explicit_failure: Option, @@ -243,14 +231,13 @@ impl Allocation { }), allocation_lifetime: Default::default(), channel_bindings: Default::default(), - last_now: now, buffered_channel_bindings: AllocRingBuffer::new(100), software: Software::new(format!("snownet; session={session_id}")) .expect("description has less then 128 chars"), explicit_failure: Default::default(), }; - allocation.send_binding_requests(); + allocation.send_binding_requests(now); allocation } @@ -266,8 +253,6 @@ impl Allocation { /// In case refreshing the allocation fails, we will attempt to make a new one. #[tracing::instrument(level = "debug", skip_all, fields(active_socket = ?self.active_socket))] pub fn refresh(&mut self, now: Instant) { - self.update_now(now); - if !self.has_allocation() && self.allocate_in_flight() { tracing::debug!("Not refreshing allocation because we are already making one"); return; @@ -277,13 +262,13 @@ impl Allocation { tracing::debug!("Attempting to make a new allocation"); self.active_socket = None; - self.send_binding_requests(); + self.send_binding_requests(now); return; } tracing::debug!("Refreshing allocation"); - self.authenticate_and_queue(make_refresh_request(self.software.clone()), None); + self.authenticate_and_queue(make_refresh_request(self.software.clone()), None, now); } #[tracing::instrument(level = "debug", skip_all, fields(%from, tid, method, class, rtt))] @@ -300,8 +285,6 @@ impl Allocation { "`from` and `local` to have the same IP version" ); - self.update_now(now); - if !self.server.matches(from) { return false; } @@ -331,20 +314,20 @@ impl Allocation { let request = self .sent_requests .get(&transaction_id) - .map(|(_, r, _, _, _)| r.attributes().map(display_attr).collect::>()); + .map(|(_, r, _)| r.attributes().map(display_attr).collect::>()); let response = message.attributes().map(display_attr).collect::>(); tracing::warn!(?request, ?response, "Message integrity check failed"); return true; // The message still indicated that it was for this `Allocation`. } - let Some((original_dst, original_request, sent_at, _, _)) = + let Some((original_dst, original_request, backoff)) = self.sent_requests.remove(&transaction_id) else { return false; }; - let rtt = now.duration_since(sent_at); + let rtt = now.duration_since(backoff.start_time()); Span::current().record("rtt", field::debug(rtt)); if tracing::enabled!(tracing::Level::DEBUG) { @@ -394,7 +377,7 @@ impl Allocation { "Request failed, re-authenticating" ); - self.authenticate_and_queue(original_request, None); + self.authenticate_and_queue(original_request, None, now); return true; } @@ -410,6 +393,7 @@ impl Allocation { self.authenticate_and_queue( make_delete_allocation_request(self.software.clone()), None, + now, ); tracing::debug!("Deleting existing allocation to re-sync"); @@ -420,6 +404,7 @@ impl Allocation { self.authenticate_and_queue( make_allocate_request(self.software.clone()), None, + now, ); tracing::debug!("Making new allocation to re-sync"); @@ -430,6 +415,7 @@ impl Allocation { self.authenticate_and_queue( make_allocate_request(self.software.clone()), None, + now, ); tracing::debug!("Making new allocation to re-sync"); @@ -538,9 +524,17 @@ impl Allocation { tracing::debug!(active_socket = %original_dst, "Updating active socket"); if self.has_allocation() { - self.authenticate_and_queue(make_refresh_request(self.software.clone()), None); + self.authenticate_and_queue( + make_refresh_request(self.software.clone()), + None, + now, + ); } else { - self.authenticate_and_queue(make_allocate_request(self.software.clone()), None); + self.authenticate_and_queue( + make_allocate_request(self.software.clone()), + None, + now, + ); } } ALLOCATE => { @@ -593,7 +587,11 @@ impl Allocation { // If we refreshed with a lifetime of 0, we deleted our previous allocation. // Make a new one. if lifetime.lifetime().is_zero() { - self.authenticate_and_queue(make_allocate_request(self.software.clone()), None); + self.authenticate_and_queue( + make_allocate_request(self.software.clone()), + None, + now, + ); return true; } @@ -655,8 +653,6 @@ impl Allocation { #[tracing::instrument(level = "debug", skip_all, fields(active_socket = ?self.active_socket))] pub fn handle_timeout(&mut self, now: Instant) { - self.update_now(now); - if self .allocation_expires_at() .is_some_and(|expires_at| now >= expires_at) @@ -672,22 +668,23 @@ impl Allocation { .as_mut() .and_then(|a| a.handle_timeout(now)) { - self.queue(addr, make_binding_request(self.software.clone()), None); + self.queue(addr, make_binding_request(self.software.clone()), None, now); } } - while let Some(timed_out_request) = - self.sent_requests - .iter() - .find_map(|(id, (_, _, sent_at, backoff, _))| { - (now.duration_since(*sent_at) >= *backoff).then_some(*id) - }) + while let Some(timed_out_request) = self + .sent_requests + .iter() + .find_map(|(id, (_, _, backoff))| (now >= backoff.next_trigger()).then_some(*id)) { - let (dst, request, _, backoff_duration, backoff) = self + let (dst, request, mut backoff) = self .sent_requests .remove(&timed_out_request) .expect("ID is from list"); + backoff.handle_timeout(now); // Must update timeout here to avoid an endless loop. + + let backoff_duration = backoff.interval(); let method = request.method(); tracing::debug!(id = ?request.transaction_id(), %method, %dst, "Request timed out after {backoff_duration:?}, re-sending"); @@ -695,9 +692,9 @@ impl Allocation { let needs_auth = method != BINDING; let queued = if needs_auth { - self.authenticate_and_queue(request, Some(backoff)) + self.authenticate_and_queue(request, Some(backoff), now) } else { - self.queue(dst, request, Some(backoff)) + self.queue(dst, request, Some(backoff), now) }; // If we have an active socket (i.e. successfully sent at least 1 BINDING request) @@ -712,10 +709,14 @@ impl Allocation { } } + for (_, _, backoff) in self.sent_requests.values_mut() { + backoff.handle_timeout(now); + } + if let Some(refresh_at) = self.refresh_allocation_at() { if (now >= refresh_at) && !self.refresh_in_flight() { tracing::debug!("Allocation is due for a refresh"); - self.authenticate_and_queue(make_refresh_request(self.software.clone()), None); + self.authenticate_and_queue(make_refresh_request(self.software.clone()), None, now); } } @@ -731,7 +732,7 @@ impl Allocation { .collect::>(); // Need to allocate here to satisfy borrow-checker. Number of channel refresh messages should be small so this shouldn't be a big impact. for message in channel_refresh_messages { - self.authenticate_and_queue(message, None); + self.authenticate_and_queue(message, None, now); } // TODO: Clean up unused channels @@ -746,15 +747,16 @@ impl Allocation { } pub fn poll_timeout(&self) -> Option { - let mut earliest_timeout = if !self.refresh_in_flight() { + let next_refresh = if !self.refresh_in_flight() { self.refresh_allocation_at() } else { None }; - for (_, (_, _, sent_at, backoff, _)) in self.sent_requests.iter() { - earliest_timeout = earliest(earliest_timeout, Some(*sent_at + *backoff)); - } + let next_timeout = self + .sent_requests + .values() + .map(|(_, _, b)| b.next_trigger()); let next_keepalive = if self.has_allocation() { self.active_socket.map(|a| a.next_binding) @@ -762,7 +764,11 @@ impl Allocation { None }; - earliest(earliest_timeout, next_keepalive) + iter::empty() + .chain(next_refresh) + .chain(next_keepalive) + .chain(next_timeout) + .min() } #[tracing::instrument(level = "debug", skip(self, now), fields(active_socket = ?self.active_socket))] @@ -772,8 +778,6 @@ impl Allocation { return; } - self.update_now(now); - if self .channel_bindings .connected_channel_to_peer(peer, now) @@ -808,6 +812,7 @@ impl Allocation { self.authenticate_and_queue( make_channel_bind_request(peer, channel, self.software.clone()), None, + now, ); } @@ -965,7 +970,7 @@ impl Allocation { } fn channel_binding_in_flight_by_number(&self, channel: u16) -> bool { - self.sent_requests.values().any(|(_, r, _, _, _)| { + self.sent_requests.values().any(|(_, r, _)| { r.method() == CHANNEL_BIND && r.get_attribute::() .is_some_and(|n| n.value() == channel) @@ -976,7 +981,7 @@ impl Allocation { let sent_requests = self .sent_requests .values() - .map(|(_, r, _, _, _)| r) + .map(|(_, r, _)| r) .filter(|message| message.method() == CHANNEL_BIND) .filter_map(|message| message.get_attribute::()) .map(|a| a.address()); @@ -990,13 +995,13 @@ impl Allocation { fn allocate_in_flight(&self) -> bool { self.sent_requests .values() - .any(|(_, r, _, _, _)| r.method() == ALLOCATE) + .any(|(_, r, _)| r.method() == ALLOCATE) } fn refresh_in_flight(&self) -> bool { self.sent_requests .values() - .any(|(_, r, _, _, _)| r.method() == REFRESH) + .any(|(_, r, _)| r.method() == REFRESH) } /// Check whether this allocation is suspended. @@ -1011,12 +1016,15 @@ impl Allocation { no_allocation && nothing_in_flight && nothing_buffered && waiting_on_nothing } - fn send_binding_requests(&mut self) { + fn send_binding_requests(&mut self, now: Instant) { + tracing::debug!(relay_socket = ?self.server, "Sending BINDING requests to pick active socket"); + if let Some(v4) = self.server.as_v4() { self.queue( (*v4).into(), make_binding_request(self.software.clone()), None, + now, ); } if let Some(v6) = self.server.as_v6() { @@ -1024,6 +1032,7 @@ impl Allocation { (*v6).into(), make_binding_request(self.software.clone()), None, + now, ); } } @@ -1033,6 +1042,7 @@ impl Allocation { &mut self, message: Message, backoff: Option, + now: Instant, ) -> bool { let Some(active_socket) = self.active_socket else { tracing::debug!( @@ -1051,7 +1061,7 @@ impl Allocation { }; let authenticated_message = authenticate(message, credentials); - self.queue(active_socket.addr, authenticated_message, backoff) + self.queue(active_socket.addr, authenticated_message, backoff, now) } fn queue( @@ -1059,21 +1069,19 @@ impl Allocation { dst: SocketAddr, message: Message, backoff: Option, + now: Instant, ) -> bool { - let mut backoff = backoff.unwrap_or(backoff::new(self.last_now, REQUEST_TIMEOUT)); - - let Some(duration) = backoff.next_backoff() else { - tracing::debug!( - "Unable to queue {} because we've exceeded its backoffs", - message.method() - ); - return false; - }; - + let backoff = backoff.unwrap_or(backoff::new(now, REQUEST_TIMEOUT)); let id = message.transaction_id(); + if backoff.is_expired(now) { + tracing::debug!(?id, method = %message.method(), %dst, "Backoff expired, giving up"); + + return false; + } + self.sent_requests - .insert(id, (dst, message.clone(), self.last_now, duration, backoff)); + .insert(id, (dst, message.clone(), backoff)); self.buffered_transmits.push_back(Transmit { src: None, dst, @@ -1083,18 +1091,6 @@ impl Allocation { true } - fn update_now(&mut self, now: Instant) { - if now <= self.last_now { - return; - } - - self.last_now = now; - - for (_, _, _, _, backoff) in self.sent_requests.values_mut() { - backoff.clock.now = now; - } - } - #[cfg(test)] fn check_message_integrity(&self, _: &Message) -> bool { true // In order to make the tests simpler, we skip the message integrity check there. @@ -1847,7 +1843,8 @@ mod tests { #[test] fn buffer_channel_bind_requests_until_we_have_allocation() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); assert_eq!(allocate.method(), ALLOCATE); @@ -1870,8 +1867,8 @@ mod tests { #[test] fn does_relay_to_with_bound_channel() { let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP4], Instant::now()); allocation.bind_channel(PEER2_IP4, Instant::now()); let channel_bind_msg = allocation.next_message().unwrap(); @@ -1891,8 +1888,8 @@ mod tests { #[test] fn does_not_relay_to_with_unbound_channel() { let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP4], Instant::now()); allocation.bind_channel(PEER2_IP4, Instant::now()); let mut buffer = channel_data_packet_buffer(b"foobar"); @@ -1905,8 +1902,8 @@ mod tests { #[test] fn failed_channel_binding_removes_state() { let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP4], Instant::now()); allocation.bind_channel(PEER2_IP4, Instant::now()); let channel_bind_msg = allocation.next_message().unwrap(); @@ -1929,8 +1926,8 @@ mod tests { #[test] fn rebinding_existing_channel_send_no_message() { let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP4], Instant::now()); allocation.bind_channel(PEER2_IP4, Instant::now()); let channel_bind_msg = allocation.next_message().unwrap(); @@ -1971,8 +1968,8 @@ mod tests { #[test] fn given_no_ip6_allocation_does_not_attempt_to_bind_channel_to_ip6_address() { let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP4], Instant::now()); allocation.bind_channel(PEER2_IP6, Instant::now()); let next_msg = allocation.next_message(); @@ -1983,8 +1980,8 @@ mod tests { #[test] fn given_no_ip4_allocation_does_not_attempt_to_bind_channel_to_ip4_address() { let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP6]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP6], Instant::now()); allocation.bind_channel(PEER2_IP4, Instant::now()); let next_msg = allocation.next_message(); @@ -1993,7 +1990,8 @@ mod tests { #[test] fn given_only_ip4_allocation_when_binding_channel_to_ip6_does_not_emit_buffered_binding() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); // Attempt to allocate let allocate = allocation.next_message().unwrap(); @@ -2018,7 +2016,8 @@ mod tests { #[test] fn initial_allocate_has_username_realm_and_message_integrity_set() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); @@ -2035,7 +2034,8 @@ mod tests { #[test] fn initial_allocate_is_missing_nonce() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); @@ -2044,7 +2044,8 @@ mod tests { #[test] fn upon_stale_nonce_reauthorizes_using_new_nonce() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4( @@ -2064,7 +2065,8 @@ mod tests { #[test] fn given_a_request_with_nonce_and_we_are_unauthorized_dont_retry() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); // Attempt to authenticate without a nonce let allocate = allocation.next_message().unwrap(); @@ -2089,7 +2091,8 @@ mod tests { #[test] fn returns_new_candidates_on_successful_allocation() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4( @@ -2117,7 +2120,8 @@ mod tests { #[test] fn calling_refresh_with_same_credentials_will_trigger_refresh() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4( @@ -2136,7 +2140,8 @@ mod tests { #[test] fn failed_refresh_will_invalidate_relay_candiates() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4( @@ -2171,7 +2176,8 @@ mod tests { #[test] fn failed_refresh_clears_all_channel_bindings() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4( @@ -2202,7 +2208,8 @@ mod tests { #[test] fn refresh_does_nothing_if_we_dont_have_an_allocation_yet() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let _allocate = allocation.next_message().unwrap(); @@ -2214,7 +2221,8 @@ mod tests { #[test] fn failed_refresh_attempts_to_make_new_allocation() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4( @@ -2234,7 +2242,7 @@ mod tests { #[test] fn allocation_is_refreshed_after_half_its_lifetime() { let mut now = Instant::now(); - let mut allocation = Allocation::for_test_ip4(now).with_binding_response(PEER1); + let mut allocation = Allocation::for_test_ip4(now).with_binding_response(PEER1, now); let allocate = allocation.next_message().unwrap(); @@ -2252,7 +2260,8 @@ mod tests { #[test] fn allocation_is_refreshed_only_once() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4( @@ -2269,7 +2278,8 @@ mod tests { #[test] fn when_refreshed_with_no_allocation_after_failed_response_tries_to_allocate() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4(&server_error(&allocate), Instant::now()); @@ -2286,7 +2296,8 @@ mod tests { #[test] fn failed_allocation_clears_buffered_channel_bindings() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); allocation.bind_channel(PEER1, Instant::now()); @@ -2314,8 +2325,8 @@ mod tests { let _guard = firezone_logging::test("debug"); let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4, RELAY_ADDR_IP6]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP4, RELAY_ADDR_IP6], Instant::now()); allocation.bind_channel(PEER1, Instant::now()); @@ -2343,7 +2354,8 @@ mod tests { #[test] fn dont_buffer_channel_bindings_twice() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); allocation.bind_channel(PEER1, Instant::now()); allocation.bind_channel(PEER1, Instant::now()); @@ -2363,7 +2375,8 @@ mod tests { #[test] fn buffered_channel_bindings_to_different_peers_work() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut allocation = + Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1, Instant::now()); allocation.bind_channel(PEER1, Instant::now()); allocation.bind_channel(PEER2_IP4, Instant::now()); @@ -2387,8 +2400,8 @@ mod tests { #[test] fn dont_send_channel_binding_if_inflight() { let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP4], Instant::now()); allocation.bind_channel(PEER1, Instant::now()); @@ -2403,8 +2416,8 @@ mod tests { #[test] fn send_channel_binding_to_second_peer_if_inflight_for_other() { let mut allocation = Allocation::for_test_ip4(Instant::now()) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4]); + .with_binding_response(PEER1, Instant::now()) + .with_allocate_response(&[RELAY_ADDR_IP4], Instant::now()); allocation.bind_channel(PEER1, Instant::now()); @@ -2433,7 +2446,7 @@ mod tests { let _guard = firezone_logging::test("trace"); let mut now = Instant::now(); - let mut allocation = Allocation::for_test_ip4(now).with_binding_response(PEER1); + let mut allocation = Allocation::for_test_ip4(now).with_binding_response(PEER1, now); // Make an allocation { @@ -2472,8 +2485,8 @@ mod tests { fn expires_allocation_invalidates_candidates() { let start = Instant::now(); let mut allocation = Allocation::for_test_ip4(start) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4, RELAY_ADDR_IP6]); + .with_binding_response(PEER1, start) + .with_allocate_response(&[RELAY_ADDR_IP4, RELAY_ADDR_IP6], start); let _drained_events = iter::from_fn(|| allocation.poll_event()).collect::>(); @@ -2492,8 +2505,8 @@ mod tests { fn invalid_credentials_invalidates_existing_allocation() { let now = Instant::now(); let mut allocation = Allocation::for_test_ip4(now) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4, RELAY_ADDR_IP6]); + .with_binding_response(PEER1, now) + .with_allocate_response(&[RELAY_ADDR_IP4, RELAY_ADDR_IP6], now); let _drained_events = iter::from_fn(|| allocation.poll_event()).collect::>(); allocation.credentials.as_mut().unwrap().nonce = Some(Nonce::new("nonce1".to_owned()).unwrap()); // Assume we had a nonce. @@ -2653,7 +2666,6 @@ mod tests { let Some(timeout) = allocation.poll_timeout() else { break; }; - allocation.handle_timeout(timeout); // We expect two transmits. @@ -2673,8 +2685,8 @@ mod tests { let mut now = Instant::now(); let mut allocation = Allocation::for_test_ip4(now) - .with_binding_response(PEER1) - .with_allocate_response(&[RELAY_ADDR_IP4]); + .with_binding_response(PEER1, now) + .with_allocate_response(&[RELAY_ADDR_IP4], now); now += BINDING_INTERVAL; allocation.handle_timeout(now); @@ -2822,16 +2834,16 @@ mod tests { ) } - fn with_binding_response(mut self, srflx_addr: SocketAddr) -> Self { + fn with_binding_response(mut self, srflx_addr: SocketAddr, now: Instant) -> Self { let binding = self.next_message().unwrap(); - self.handle_test_input_ip4(&binding_response(&binding, srflx_addr), self.last_now); + self.handle_test_input_ip4(&binding_response(&binding, srflx_addr), now); self } - fn with_allocate_response(mut self, relay_addrs: &[SocketAddr]) -> Self { + fn with_allocate_response(mut self, relay_addrs: &[SocketAddr], now: Instant) -> Self { let allocate = self.next_message().unwrap(); - self.handle_test_input_ip4(&allocate_response(&allocate, relay_addrs), self.last_now); + self.handle_test_input_ip4(&allocate_response(&allocate, relay_addrs), now); self } diff --git a/rust/connlib/snownet/src/backoff.rs b/rust/connlib/snownet/src/backoff.rs index 036729d99..f29203353 100644 --- a/rust/connlib/snownet/src/backoff.rs +++ b/rust/connlib/snownet/src/backoff.rs @@ -1,31 +1,51 @@ use std::time::{Duration, Instant}; -pub type ExponentialBackoff = backoff::exponential::ExponentialBackoff; +const MULTIPLIER: f32 = 1.5; +const MAX_ELAPSED_TIME: Duration = Duration::from_secs(8); #[derive(Debug)] -pub struct ManualClock { - pub now: Instant, +pub struct ExponentialBackoff { + start_time: Instant, + next_trigger: Instant, + interval: Duration, } -impl backoff::Clock for ManualClock { - fn now(&self) -> Instant { - self.now +impl ExponentialBackoff { + pub(crate) fn handle_timeout(&mut self, now: Instant) { + if self.is_expired(now) { + return; + } + + if now < self.next_trigger { + return; + } + + self.interval = Duration::from_secs_f32(self.interval.as_secs_f32() * MULTIPLIER); + self.next_trigger += self.interval; + } + + pub(crate) fn next_trigger(&self) -> Instant { + self.next_trigger + } + + pub(crate) fn is_expired(&self, at: Instant) -> bool { + at >= self.start_time + MAX_ELAPSED_TIME + } + + pub(crate) fn interval(&self) -> Duration { + self.interval + } + + pub(crate) fn start_time(&self) -> Instant { + self.start_time } } -pub fn new( - now: Instant, - initial_interval: Duration, -) -> backoff::exponential::ExponentialBackoff { +pub fn new(now: Instant, interval: Duration) -> ExponentialBackoff { ExponentialBackoff { - current_interval: initial_interval, - initial_interval, - randomization_factor: 0., - multiplier: backoff::default::MULTIPLIER, - max_interval: Duration::from_millis(backoff::default::MAX_INTERVAL_MILLIS), + interval, start_time: now, - max_elapsed_time: Some(Duration::from_secs(10)), - clock: ManualClock { now }, + next_trigger: now + interval, } } @@ -45,3 +65,34 @@ pub fn steps(start: Instant) -> [Instant; 4] { start + secs(1.0 + 1.5 + 2.25 + 3.375), ] } + +#[cfg(test)] +mod tests { + use super::*; + use std::{collections::BTreeSet, iter}; + + #[test] + fn backoff_steps() { + let mut now = Instant::now(); + + let steps = Vec::from_iter( + iter::from_fn({ + let mut backoff = super::new(now, Duration::from_secs(1)); + + move || { + if backoff.is_expired(now) { + return None; + } + + now += Duration::from_millis(100); // Purposely updating more often than the interval. + backoff.handle_timeout(now); + + Some(backoff.next_trigger()) + } + }) + .collect::>(), + ); + + assert_eq!(&steps, &super::steps(now)); + } +}