diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 76245f315..5c223482d 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -58,7 +58,7 @@ pub struct Allocation { timed_out_requests: RingBuffer, channel_bindings: ChannelBindings, - buffered_channel_bindings: BufferedChannelBindings, + buffered_channel_bindings: RingBuffer, last_now: Instant, @@ -113,7 +113,7 @@ impl Allocation { allocation_lifetime: Default::default(), channel_bindings: Default::default(), last_now: now, - buffered_channel_bindings: Default::default(), + buffered_channel_bindings: RingBuffer::new(100), backoff: backoff::new(now, REQUEST_TIMEOUT), timed_out_requests: RingBuffer::new(100), }; @@ -299,6 +299,11 @@ impl Allocation { .attributes() .find_map(relay_candidate(|s| s.is_ipv6())); + if maybe_ip4_relay_candidate.is_none() && maybe_ip6_relay_candidate.is_none() { + tracing::warn!("Relay sent a successful allocate response without addresses"); + return true; + } + self.allocation_lifetime = Some((now, lifetime)); update_candidate( maybe_srflx_candidate, @@ -324,18 +329,12 @@ impl Allocation { "Updated candidates of allocation" ); - while let Some(buffered) = self.buffered_channel_bindings.pop_front() { - let Some(peer) = buffered.get_attribute::() else { - debug_assert!(false, "channel binding must have peer address"); - continue; - }; - - if !self.can_relay_to(peer.address()) { - tracing::debug!("Allocation cannot relay to this IP version"); - continue; - } - - self.authenticate_and_queue(buffered); + while let Some(peer) = self.buffered_channel_bindings.pop() { + debug_assert!( + self.has_allocation(), + "We just received a successful allocation response" + ); + self.bind_channel(peer, now); } } REFRESH => { @@ -501,17 +500,10 @@ impl Allocation { return; } - let Some(channel) = self.channel_bindings.new_channel_to_peer(peer, now) else { - tracing::warn!("All channels are exhausted"); - return; - }; - - let msg = make_channel_bind_request(peer, channel); - if !self.has_allocation() { tracing::debug!("No allocation yet, buffering channel binding"); - self.buffered_channel_bindings.push_back(msg); + self.buffered_channel_bindings.push(peer); return; } @@ -520,7 +512,12 @@ impl Allocation { return; } - self.authenticate_and_queue(msg); + let Some(channel) = self.channel_bindings.new_channel_to_peer(peer, now) else { + tracing::warn!("All channels are exhausted"); + return; + }; + + self.authenticate_and_queue(make_channel_bind_request(peer, channel)); } pub fn encode_to_slice( @@ -627,17 +624,18 @@ impl Allocation { } fn channel_binding_in_flight_by_peer(&self, peer: SocketAddr) -> bool { - let sent_requests = self.sent_requests.values().map(|(r, _, _)| r); - let buffered = self.buffered_channel_bindings.inner.iter(); + let sent_requests = self + .sent_requests + .values() + .map(|(r, _, _)| r) + .filter(|message| message.method() == CHANNEL_BIND) + .filter_map(|message| message.get_attribute::()) + .map(|a| a.address()); + let buffered = self.buffered_channel_bindings.iter().copied(); - sent_requests.chain(buffered).any(|message| { - let is_binding = message.method() == CHANNEL_BIND; - let is_for_peer = message - .get_attribute::() - .is_some_and(|n| n.address() == peer); - - is_binding && is_for_peer - }) + sent_requests + .chain(buffered) + .any(|buffered| buffered == peer) } fn allocate_in_flight(&self) -> bool { @@ -1050,39 +1048,6 @@ impl Channel { } } -#[derive(Debug, Default)] -struct BufferedChannelBindings { - inner: VecDeque>, -} - -impl BufferedChannelBindings { - /// Adds a new `CHANNEL-BIND` message to this buffer. - /// - /// The buffer has a fixed size of 10 to avoid unbounded memory growth. - /// All prior messages are cleared once we outgrow the buffer. - /// Very likely, we buffer `CHANNEL-BIND` messages only for a brief period of time. - /// However, it might also happen that we can only re-connect to a TURN server after an extended period of downtime. - /// Chances are that we don't need any of the old channels any more, and that the new ones are much more relevant. - fn push_back(&mut self, msg: Message) { - debug_assert_eq!(msg.method(), CHANNEL_BIND); - - if self.inner.len() == 10 { - tracing::debug!("Clearing buffered channel-data messages"); - self.inner.clear() - } - - self.inner.push_back(msg); - } - - fn pop_front(&mut self) -> Option> { - self.inner.pop_front() - } - - fn clear(&mut self) { - self.inner.clear() - } -} - #[cfg(test)] mod tests { use super::*; @@ -1391,28 +1356,6 @@ mod tests { assert!(expected_backoffs.is_empty()) } - #[test] - fn discards_old_channel_bindings_once_we_outgrow_buffer() { - let mut buffered_channel_bindings = BufferedChannelBindings::default(); - - for c in 0..11 { - buffered_channel_bindings.push_back(make_channel_bind_request( - PEER1, - ChannelBindings::FIRST_CHANNEL + c, - )); - } - - let msg = buffered_channel_bindings.pop_front().unwrap(); - assert!( - buffered_channel_bindings.pop_front().is_none(), - "no more messages" - ); - assert_eq!( - msg.get_attribute::().unwrap().value(), - ChannelBindings::FIRST_CHANNEL + 10 - ); - } - #[test] fn given_no_ip6_allocation_does_not_attempt_to_bind_channel_to_ip6_address() { let mut allocation = @@ -1801,10 +1744,10 @@ mod tests { let channel_bind_peer_2 = allocation.next_message().unwrap(); assert_eq!(channel_bind_peer_1.method(), CHANNEL_BIND); - assert_eq!(peer_address(&channel_bind_peer_1), PEER1); + assert_eq!(peer_address(&channel_bind_peer_1), PEER2_IP4); assert_eq!(channel_bind_peer_2.method(), CHANNEL_BIND); - assert_eq!(peer_address(&channel_bind_peer_2), PEER2_IP4); + assert_eq!(peer_address(&channel_bind_peer_2), PEER1); } #[test] diff --git a/rust/connlib/snownet/src/ringbuffer.rs b/rust/connlib/snownet/src/ringbuffer.rs index d131c45b5..b4b026760 100644 --- a/rust/connlib/snownet/src/ringbuffer.rs +++ b/rust/connlib/snownet/src/ringbuffer.rs @@ -24,6 +24,18 @@ impl RingBuffer { initial_len != self.buffer.len() } + pub fn pop(&mut self) -> Option { + self.buffer.pop() + } + + pub fn clear(&mut self) { + self.buffer.clear(); + } + + pub fn iter(&self) -> impl Iterator + '_ { + self.buffer.iter() + } + #[cfg(test)] fn inner(&self) -> &[T] { self.buffer.as_slice()