diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 3a960b2f0..e80669f51 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -1,4 +1,9 @@ -use crate::node::Transmit; +use crate::{ + backoff::{self, ExponentialBackoff}, + node::Transmit, + utils::earliest, +}; +use ::backoff::backoff::Backoff; use bytecodec::{DecodeExt as _, EncodeExt as _}; use rand::random; use std::{ @@ -45,12 +50,13 @@ pub struct Allocation { buffered_transmits: VecDeque>, new_candidates: VecDeque, - sent_requests: HashMap, Instant)>, + backoff: ExponentialBackoff, + sent_requests: HashMap, Instant, Duration)>, channel_bindings: ChannelBindings, - buffered_channel_bindings: VecDeque>, + buffered_channel_bindings: BufferedChannelBindings, - last_now: Option, + last_now: Instant, username: Username, password: String, @@ -59,8 +65,14 @@ pub struct Allocation { } impl Allocation { - pub fn new(server: SocketAddr, username: Username, password: String, realm: Realm) -> Self { - Self { + pub fn new( + server: SocketAddr, + username: Username, + password: String, + realm: Realm, + now: Instant, + ) -> Self { + let mut allocation = Self { server, last_srflx_candidate: Default::default(), ip4_allocation: Default::default(), @@ -74,9 +86,16 @@ impl Allocation { nonce: Default::default(), allocation_lifetime: Default::default(), channel_bindings: Default::default(), - last_now: Default::default(), + last_now: now, buffered_channel_bindings: Default::default(), - } + backoff: backoff::new(now, REQUEST_TIMEOUT), + }; + + tracing::debug!(%server, "Requesting new allocation"); + + allocation.authenticate_and_queue(make_allocate_request()); + + allocation } pub fn current_candidates(&self) -> impl Iterator { @@ -96,9 +115,7 @@ impl Allocation { packet: &[u8], now: Instant, ) -> bool { - if Some(now) > self.last_now { - self.last_now = Some(now); - } + self.update_now(now); if from != self.server { return false; @@ -108,12 +125,14 @@ impl Allocation { return false; }; - let Some((original_request, sent_at)) = + let Some((original_request, sent_at, _)) = self.sent_requests.remove(&message.transaction_id()) else { return false; }; + self.backoff.reset(); + let rtt = now.duration_since(sent_at); tracing::debug!(id = ?original_request.transaction_id(), method = %original_request.method(), ?rtt); @@ -137,7 +156,7 @@ impl Allocation { "Request failed, re-authenticating" ); - self.authenticate_and_queue(original_request, now); + self.authenticate_and_queue(original_request); return true; } @@ -225,7 +244,7 @@ impl Allocation { ); while let Some(buffered) = self.buffered_channel_bindings.pop_front() { - self.authenticate_and_queue(buffered, now); + self.authenticate_and_queue(buffered); } } REFRESH => { @@ -295,37 +314,29 @@ impl Allocation { } pub fn handle_timeout(&mut self, now: Instant) { - if Some(now) > self.last_now { - self.last_now = Some(now); - } - - if !self.has_allocation() && !self.allocate_in_flight() { - tracing::debug!(server = %self.server, "Request new allocation"); - - self.authenticate_and_queue(make_allocate_request(), now); - } + self.update_now(now); while let Some(timed_out_request) = - self.sent_requests.iter().find_map(|(id, (_, sent_at))| { - (now.duration_since(*sent_at) >= REQUEST_TIMEOUT).then_some(*id) - }) + self.sent_requests + .iter() + .find_map(|(id, (_, sent_at, backoff))| { + (now.duration_since(*sent_at) >= *backoff).then_some(*id) + }) { - let (request, _) = self + let (request, _, _) = self .sent_requests .remove(&timed_out_request) .expect("ID is from list"); tracing::debug!(id = ?request.transaction_id(), method = %request.method(), "Request timed out, re-sending"); - self.authenticate_and_queue(request, now); + self.authenticate_and_queue(request); } - if let Some((received_at, lifetime)) = self.allocation_lifetime { - let refresh_after = lifetime / 2; - - if now > received_at + refresh_after { - tracing::debug!("Allocation is at 50% of its lifetime, refreshing"); - self.authenticate_and_queue(make_refresh_request(), now); + if let Some(refresh_at) = self.refresh_allocation_at() { + if now > refresh_at { + tracing::debug!("Allocation is due for a refresh"); + self.authenticate_and_queue(make_refresh_request()); } } @@ -336,7 +347,7 @@ impl Allocation { .collect::>(); // Need to allocate here to satisfy borrow-checker. Number of channel refresh messages should be small so this shouldn't be a big impact. for message in refresh_messages { - self.authenticate_and_queue(message, now); + self.authenticate_and_queue(message); } // TODO: Clean up unused channels @@ -350,11 +361,19 @@ impl Allocation { self.buffered_transmits.pop_front() } - // pub fn poll_timeout(&self) -> Option { - // None // TODO: Implement this. - // } + pub fn poll_timeout(&self) -> Option { + let mut earliest_timeout = self.refresh_allocation_at(); + + for (_, (_, sent_at, backoff)) in self.sent_requests.iter() { + earliest_timeout = earliest(earliest_timeout, Some(*sent_at + *backoff)); + } + + earliest_timeout + } pub fn bind_channel(&mut self, peer: SocketAddr, now: Instant) { + self.update_now(now); + if self.channel_bindings.channel_to_peer(peer, now).is_some() { tracing::debug!(relay = %self.server, %peer, "Already got a channel"); return; @@ -374,7 +393,7 @@ impl Allocation { return; } - self.authenticate_and_queue(msg, now); + self.authenticate_and_queue(msg); } pub fn encode_to_slice( @@ -403,18 +422,20 @@ impl Allocation { Some(channel_data) } + fn refresh_allocation_at(&self) -> Option { + let (received_at, lifetime) = self.allocation_lifetime?; + + let refresh_after = lifetime / 2; + + Some(received_at + refresh_after) + } + fn has_allocation(&self) -> bool { self.ip4_allocation.is_some() || self.ip6_allocation.is_some() } - fn allocate_in_flight(&self) -> bool { - self.sent_requests - .values() - .any(|(r, _)| r.method() == ALLOCATE) - } - fn channel_binding_in_flight(&self, channel: u16) -> bool { - self.sent_requests.values().any(|(r, _)| { + self.sent_requests.values().any(|(r, _, _)| { r.method() == BINDING && r.get_attribute::() .is_some_and(|n| n.value() == channel) @@ -455,18 +476,33 @@ impl Allocation { message } - fn authenticate_and_queue(&mut self, message: Message, now: Instant) { + fn authenticate_and_queue(&mut self, message: Message) { + let Some(backoff) = self.backoff.next_backoff() else { + tracing::warn!( + "Unable to queue {} because we've exceeded our backoffs", + message.method() + ); + return; + }; + let authenticated_message = self.authenticate(message); let id = authenticated_message.transaction_id(); self.sent_requests - .insert(id, (authenticated_message.clone(), now)); + .insert(id, (authenticated_message.clone(), self.last_now, backoff)); self.buffered_transmits.push_back(Transmit { src: None, dst: self.server, payload: encode(authenticated_message).into(), }); } + + fn update_now(&mut self, now: Instant) { + if now > self.last_now { + self.last_now = now; + self.backoff.clock.now = now; + } + } } fn update_candidate( @@ -769,6 +805,35 @@ impl Channel { } } +#[derive(Debug, Default)] +struct BufferedChannelBindings { + inner: VecDeque>, +} + +impl BufferedChannelBindings { + /// Adds a new `CHANNEL-BIND` message to this buffer. + /// + /// The buffer has a fixed size of 10 to avoid unbounded memory growth. + /// All prior messages are cleared once we outgrow the buffer. + /// Very likely, we buffer `CHANNEL-BIND` messages only for a brief period of time. + /// However, it might also happen that we can only re-connect to a TURN server after an extended period of downtime. + /// Chances are that we don't need any of the old channels any more, and that the new ones are much more relevant. + fn push_back(&mut self, msg: Message) { + debug_assert_eq!(msg.method(), CHANNEL_BIND); + + if self.inner.len() == 10 { + tracing::debug!("Clearing buffered channel-data messages"); + self.inner.clear() + } + + self.inner.push_back(msg); + } + + fn pop_front(&mut self) -> Option> { + self.inner.pop_front() + } +} + #[cfg(test)] mod tests { use super::*; @@ -954,15 +1019,19 @@ mod tests { Username::new("foobar".to_owned()).unwrap(), "baz".to_owned(), Realm::new("firezone".to_owned()).unwrap(), + Instant::now(), ); + let allocate = next_stun_message(&mut allocation).unwrap(); + assert_eq!(allocate.method(), ALLOCATE); + allocation.bind_channel(PEER1, Instant::now()); assert!( next_stun_message(&mut allocation).is_none(), "no messages to be sent if we don't have an allocation" ); - make_allocation(&mut allocation, PEER1); + make_allocation(&mut allocation, allocate.transaction_id(), PEER1); let message = next_stun_message(&mut allocation).unwrap(); assert_eq!(message.method(), CHANNEL_BIND); @@ -975,9 +1044,11 @@ mod tests { Username::new("foobar".to_owned()).unwrap(), "baz".to_owned(), Realm::new("firezone".to_owned()).unwrap(), + Instant::now(), ); - make_allocation(&mut allocation, PEER1); + let allocate = next_stun_message(&mut allocation).unwrap(); + make_allocation(&mut allocation, allocate.transaction_id(), PEER1); allocation.bind_channel(PEER2, Instant::now()); let message = allocation.encode_to_vec(PEER2, b"foobar", Instant::now()); @@ -992,9 +1063,11 @@ mod tests { Username::new("foobar".to_owned()).unwrap(), "baz".to_owned(), Realm::new("firezone".to_owned()).unwrap(), + Instant::now(), ); - make_allocation(&mut allocation, PEER1); + let allocate = next_stun_message(&mut allocation).unwrap(); + make_allocation(&mut allocation, allocate.transaction_id(), PEER1); allocation.bind_channel(PEER2, Instant::now()); let channel_bind_msg = next_stun_message(&mut allocation).unwrap(); @@ -1023,9 +1096,11 @@ mod tests { Username::new("foobar".to_owned()).unwrap(), "baz".to_owned(), Realm::new("firezone".to_owned()).unwrap(), + Instant::now(), ); - make_allocation(&mut allocation, PEER1); + let allocate = next_stun_message(&mut allocation).unwrap(); + make_allocation(&mut allocation, allocate.transaction_id(), PEER1); allocation.bind_channel(PEER2, Instant::now()); let channel_bind_msg = next_stun_message(&mut allocation).unwrap(); @@ -1042,6 +1117,57 @@ mod tests { assert!(next_msg.is_none()) } + #[test] + fn retries_requests_using_backoff_and_gives_up_eventually() { + let start = Instant::now(); + let mut allocation = Allocation::new( + RELAY, + Username::new("foobar".to_owned()).unwrap(), + "baz".to_owned(), + Realm::new("firezone".to_owned()).unwrap(), + start, + ); + + let mut expected_backoffs = VecDeque::from(backoff::steps(start)); + + loop { + let Some(timeout) = allocation.poll_timeout() else { + break; + }; + + assert_eq!(expected_backoffs.pop_front().unwrap(), timeout); + + assert!(allocation.poll_transmit().is_some()); + assert!(allocation.poll_transmit().is_none()); + + allocation.handle_timeout(timeout); + } + + assert!(expected_backoffs.is_empty()) + } + + #[test] + fn discards_old_channel_bindings_once_we_outgrow_buffer() { + let mut buffered_channel_bindings = BufferedChannelBindings::default(); + + for c in 0..11 { + buffered_channel_bindings.push_back(make_channel_bind_request( + PEER1, + ChannelBindings::FIRST_CHANNEL + c, + )); + } + + let msg = buffered_channel_bindings.pop_front().unwrap(); + assert!( + buffered_channel_bindings.pop_front().is_none(), + "no more messages" + ); + assert_eq!( + msg.get_attribute::().unwrap().value(), + ChannelBindings::FIRST_CHANNEL + 10 + ); + } + fn ch(peer: SocketAddr, now: Instant) -> Channel { Channel { peer, @@ -1057,13 +1183,11 @@ mod tests { Some(decode(&transmit.payload).unwrap().unwrap()) } - fn make_allocation(allocation: &mut Allocation, local: SocketAddr) { - allocation.handle_timeout(Instant::now()); - let message = next_stun_message(allocation).unwrap(); + fn make_allocation(allocation: &mut Allocation, allocate_id: TransactionId, local: SocketAddr) { allocation.handle_input( RELAY, local, - &encode(allocate_response(message.transaction_id())), + &encode(allocate_response(allocate_id)), Instant::now(), ); } diff --git a/rust/connlib/snownet/src/backoff.rs b/rust/connlib/snownet/src/backoff.rs index c822e3e7f..8a4d06d81 100644 --- a/rust/connlib/snownet/src/backoff.rs +++ b/rust/connlib/snownet/src/backoff.rs @@ -30,3 +30,35 @@ pub fn new( clock: ManualClock { now }, } } + +/// Calculates our backoff times, starting from the given [`Instant`]. +/// +/// The current strategy is multiplying the previous interval by 1.5 and adding them up. +#[cfg(test)] +pub fn steps(start: Instant) -> [Instant; 19] { + fn secs(secs: f64) -> Duration { + Duration::from_micros((secs * 1_000_000.0) as u64) + } + + [ + start + secs(5.0), + start + secs(5.0 + 7.5), + start + secs(5.0 + 7.5 + 11.25), + start + secs(5.0 + 7.5 + 11.25 + 16.875), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 1.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 2.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 3.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 4.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 5.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 6.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 7.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 8.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 9.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 10.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 11.0), + start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 12.0), + ] +} diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index 547c9bd87..29d25da12 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -8,6 +8,7 @@ mod info; mod ip_packet; mod node; mod stun_binding; +mod utils; pub use info::ConnectionInfo; pub use ip_packet::{IpPacket, MutableIpPacket}; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 4898ad2d3..67b327184 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -23,6 +23,7 @@ use crate::allocation::Allocation; use crate::index::IndexLfsr; use crate::info::ConnectionInfo; use crate::stun_binding::StunBinding; +use crate::utils::earliest; use crate::{IpPacket, MutableIpPacket}; use boringtun::noise::errors::WireGuardError; use std::borrow::Cow; @@ -450,6 +451,9 @@ where for b in self.bindings.values_mut() { connection_timeout = earliest(connection_timeout, b.poll_timeout()); } + for a in self.allocations.values_mut() { + connection_timeout = earliest(connection_timeout, a.poll_timeout()); + } earliest(connection_timeout, self.next_rate_limiter_reset) } @@ -809,7 +813,7 @@ where self.allocations.insert( *server, - Allocation::new(*server, username, password.clone(), realm), + Allocation::new(*server, username, password.clone(), realm, self.last_now), ); } } @@ -1168,12 +1172,3 @@ impl Connection { } } } - -fn earliest(left: Option, right: Option) -> Option { - match (left, right) { - (None, None) => None, - (Some(left), Some(right)) => Some(std::cmp::min(left, right)), - (Some(left), None) => Some(left), - (None, Some(right)) => Some(right), - } -} diff --git a/rust/connlib/snownet/src/stun_binding.rs b/rust/connlib/snownet/src/stun_binding.rs index a285bafb4..070757e6c 100644 --- a/rust/connlib/snownet/src/stun_binding.rs +++ b/rust/connlib/snownet/src/stun_binding.rs @@ -315,32 +315,7 @@ mod tests { let start = Instant::now(); let mut stun_binding = StunBinding::new(SERVER1, start); - fn secs(secs: f64) -> Duration { - Duration::from_micros((secs * 1_000_000.0) as u64) - } - - // The backoff strategy is to increment the previous interval by 1.5 - let mut expected_backoffs = VecDeque::from([ - start + secs(5.0), - start + secs(5.0 + 7.5), - start + secs(5.0 + 7.5 + 11.25), - start + secs(5.0 + 7.5 + 11.25 + 16.875), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 1.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 2.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 3.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 4.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 5.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 6.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 7.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 8.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 9.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 10.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 11.0), - start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 12.0), - ]); + let mut expected_backoffs = VecDeque::from(backoff::steps(start)); loop { let Some(timeout) = stun_binding.poll_timeout() else { diff --git a/rust/connlib/snownet/src/utils.rs b/rust/connlib/snownet/src/utils.rs new file mode 100644 index 000000000..b26639545 --- /dev/null +++ b/rust/connlib/snownet/src/utils.rs @@ -0,0 +1,10 @@ +use std::time::Instant; + +pub fn earliest(left: Option, right: Option) -> Option { + match (left, right) { + (None, None) => None, + (Some(left), Some(right)) => Some(std::cmp::min(left, right)), + (Some(left), None) => Some(left), + (None, Some(right)) => Some(right), + } +}