diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 61ab58119..faaa4cfe1 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -1,6 +1,6 @@ use crate::{ backoff::{self, ExponentialBackoff}, - node::{CandidateEvent, SessionId, Transmit}, + node::{SessionId, Transmit}, ringbuffer::RingBuffer, utils::earliest, EncryptedPacket, @@ -70,7 +70,7 @@ pub struct Allocation { allocation_lifetime: Option<(Instant, Duration)>, buffered_transmits: VecDeque>, - events: VecDeque, + events: VecDeque, sent_requests: BTreeMap< TransactionId, @@ -91,6 +91,12 @@ pub struct Allocation { credentials: Option, } +#[derive(Debug, PartialEq)] +pub(crate) enum Event { + New(Candidate), + Invalid(Candidate), +} + #[derive(Debug, Clone)] struct Credentials { username: Username, @@ -227,15 +233,10 @@ impl Allocation { allocation } - pub fn current_candidates(&self) -> impl Iterator { - [ - self.ip4_srflx_candidate.clone(), - self.ip6_srflx_candidate.clone(), - self.ip4_allocation.clone(), - self.ip6_allocation.clone(), - ] - .into_iter() - .flatten() + pub fn current_relay_candidates(&self) -> impl Iterator { + [self.ip4_allocation.clone(), self.ip6_allocation.clone()] + .into_iter() + .flatten() } /// Refresh this allocation. @@ -654,7 +655,7 @@ impl Allocation { // TODO: Clean up unused channels } - pub fn poll_event(&mut self) -> Option { + pub fn poll_event(&mut self) -> Option { self.events.pop_front() } @@ -827,11 +828,11 @@ impl Allocation { tracing::info!(active_socket = ?self.active_socket, "Invalidating allocation"); if let Some(candidate) = self.ip4_allocation.take() { - self.events.push_back(CandidateEvent::Invalid(candidate)) + self.events.push_back(Event::Invalid(candidate)) } if let Some(candidate) = self.ip6_allocation.take() { - self.events.push_back(CandidateEvent::Invalid(candidate)) + self.events.push_back(Event::Invalid(candidate)) } self.channel_bindings.clear(); @@ -1047,17 +1048,17 @@ fn authenticate(message: Message, credentials: &Credentials) -> Messa fn update_candidate( maybe_new: Option, maybe_current: &mut Option, - events: &mut VecDeque, + events: &mut VecDeque, ) { match (maybe_new, &maybe_current) { (Some(new), Some(current)) if &new != current => { - events.push_back(CandidateEvent::New(new.clone())); - events.push_back(CandidateEvent::Invalid(current.clone())); + events.push_back(Event::New(new.clone())); + events.push_back(Event::Invalid(current.clone())); *maybe_current = Some(new); } (Some(new), None) => { *maybe_current = Some(new.clone()); - events.push_back(CandidateEvent::New(new)); + events.push_back(Event::New(new)); } _ => {} } @@ -1926,14 +1927,14 @@ mod tests { let next_event = allocation.poll_event(); assert_eq!( next_event, - Some(CandidateEvent::New( + Some(Event::New( Candidate::server_reflexive(PEER1, PEER1, Protocol::Udp).unwrap() )) ); let next_event = allocation.poll_event(); assert_eq!( next_event, - Some(CandidateEvent::New( + Some(Event::New( Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap() )) ); @@ -1978,21 +1979,20 @@ mod tests { assert_eq!( allocation.poll_event(), - Some(CandidateEvent::Invalid( + Some(Event::Invalid( Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap() )) ); assert_eq!( allocation.poll_event(), - Some(CandidateEvent::Invalid( + Some(Event::Invalid( Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap() )) ); assert!(allocation.poll_event().is_none()); assert_eq!( - allocation.current_candidates().collect::>(), - vec![Candidate::server_reflexive(PEER1, PEER1, Protocol::Udp).unwrap()], - "server-reflexive candidate should still be valid after refresh" + allocation.current_relay_candidates().collect::>(), + vec![], ) } @@ -2310,8 +2310,8 @@ mod tests { assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), ] ) } @@ -2330,8 +2330,8 @@ mod tests { assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), ] ) } @@ -2362,8 +2362,8 @@ mod tests { assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), ] ); assert_eq!( @@ -2451,10 +2451,10 @@ mod tests { assert_eq!( events, vec![ - CandidateEvent::New( + Event::New( Candidate::server_reflexive(PEER2_IP4, PEER2_IP4, Protocol::Udp).unwrap() ), - CandidateEvent::New( + Event::New( Candidate::server_reflexive(PEER2_IP6, PEER2_IP6, Protocol::Udp).unwrap() ) ] diff --git a/rust/connlib/snownet/src/candidate_set.rs b/rust/connlib/snownet/src/candidate_set.rs new file mode 100644 index 000000000..ad07e2eb9 --- /dev/null +++ b/rust/connlib/snownet/src/candidate_set.rs @@ -0,0 +1,28 @@ +use std::collections::HashSet; + +use itertools::Itertools; +use str0m::Candidate; + +/// Custom "set" implementation for [`Candidate`]s based on a [`HashSet`] with an enforced ordering when iterating. +#[derive(Debug, Default)] +pub struct CandidateSet { + inner: HashSet, +} + +impl CandidateSet { + pub fn insert(&mut self, c: Candidate) -> bool { + self.inner.insert(c) + } + + pub fn clear(&mut self) { + self.inner.clear() + } + + #[expect( + clippy::disallowed_methods, + reason = "We are guaranteeing a stable ordering" + )] + pub fn iter(&self) -> impl Iterator { + self.inner.iter().sorted_by_key(|c| c.prio()) + } +} diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index a0aa1dc12..aeeb1abbd 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -2,6 +2,7 @@ mod allocation; mod backoff; +mod candidate_set; mod channel_data; mod index; mod node; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 7f1a2f51a..35fe00421 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -1,4 +1,5 @@ -use crate::allocation::{Allocation, RelaySocket, Socket}; +use crate::allocation::{self, Allocation, RelaySocket, Socket}; +use crate::candidate_set::CandidateSet; use crate::index::IndexLfsr; use crate::ringbuffer::RingBuffer; use crate::stats::{ConnectionStats, NodeStats}; @@ -13,7 +14,6 @@ use hex_display::HexDisplayExt; use ip_packet::{ ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, IpPacketBuf, MAX_DATAGRAM_PAYLOAD, }; -use itertools::Itertools as _; use rand::rngs::StdRng; use rand::seq::IteratorRandom; use rand::{random, Rng, SeedableRng}; @@ -114,7 +114,8 @@ pub struct Node { index: IndexLfsr, rate_limiter: Arc, - host_candidates: Vec, // `Candidate` doesn't implement `PartialOrd` so we cannot use a `BTreeSet`. Linear search is okay because we expect this vec to be <100 elements + /// Host and server-reflexive candidates that are shared between all connections. + shared_candidates: CandidateSet, buffered_transmits: VecDeque>, next_rate_limiter_reset: Option, @@ -168,7 +169,7 @@ where mode: T::new(), index: IndexLfsr::default(), rate_limiter: Arc::new(RateLimiter::new(public_key, HANDSHAKE_RATE_LIMIT)), - host_candidates: Default::default(), + shared_candidates: Default::default(), buffered_transmits: VecDeque::default(), next_rate_limiter_reset: None, pending_events: VecDeque::default(), @@ -205,7 +206,7 @@ where self.pending_events.extend(closed_connections); - self.host_candidates.clear(); + self.shared_candidates.clear(); self.connections.clear(); self.buffered_transmits.clear(); @@ -706,9 +707,7 @@ where agent.handle_timeout(now); if self.allocations.is_empty() { - tracing::warn!( - "No TURN servers connected; connection will very likely fail to establish" - ); + tracing::warn!("No TURN servers connected; connection may fail to establish"); } Connection { @@ -744,12 +743,10 @@ where fn add_local_as_host_candidate(&mut self, local: SocketAddr) -> Result<(), Error> { let host_candidate = Candidate::host(local, Protocol::Udp)?; - if self.host_candidates.contains(&host_candidate) { + if self.shared_candidates.insert(host_candidate.clone()) { return Ok(()); } - self.host_candidates.push(host_candidate.clone()); - for (cid, agent, _span) in self.connections.agents_mut() { add_local_candidate(cid, agent, host_candidate.clone(), &mut self.pending_events); } @@ -907,20 +904,18 @@ where tracing::trace!(%rid, ?event); match event { - CandidateEvent::New(candidate) + allocation::Event::New(candidate) if candidate.kind() == CandidateKind::ServerReflexive => { - for (cid, agent, _span) in self.connections.agents_mut() { - add_local_candidate(cid, agent, candidate.clone(), &mut self.pending_events) - } + self.shared_candidates.insert(candidate); } - CandidateEvent::New(candidate) => { + allocation::Event::New(candidate) => { for (cid, agent, _span) in self.connections.connecting_agents_by_relay_mut(rid) { add_local_candidate(cid, agent, candidate.clone(), &mut self.pending_events) } } - CandidateEvent::Invalid(candidate) => { + allocation::Event::Invalid(candidate) => { for (cid, agent, _span) in self.connections.agents_mut() { remove_local_candidate(cid, agent, &candidate, &mut self.pending_events); } @@ -1114,17 +1109,7 @@ where selected_relay: Option, agent: &mut IceAgent, ) { - for candidate in self.host_candidates.iter().cloned() { - add_local_candidate(connection, agent, candidate, &mut self.pending_events); - } - - for candidate in self - .allocations - .values() - .flat_map(|a| a.current_candidates()) - .filter(|c| c.kind() == CandidateKind::ServerReflexive) - .unique() - { + for candidate in self.shared_candidates.iter().cloned() { add_local_candidate(connection, agent, candidate, &mut self.pending_events); } @@ -1138,10 +1123,7 @@ where return; }; - for candidate in allocation - .current_candidates() - .filter(|c| c.kind() == CandidateKind::Relayed) - { + for candidate in allocation.current_relay_candidates() { add_local_candidate(connection, agent, candidate, &mut self.pending_events); } } @@ -1406,10 +1388,7 @@ fn invalidate_allocation_candidates( RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, { for (cid, agent, _guard) in connections.agents_mut() { - for candidate in allocation - .current_candidates() - .filter(|c| c.kind() == CandidateKind::Relayed) - { + for candidate in allocation.current_relay_candidates() { remove_local_candidate(cid, agent, &candidate, pending_events); } } @@ -1576,12 +1555,6 @@ impl<'a> Transmit<'a> { } } -#[derive(Debug, PartialEq)] -pub(crate) enum CandidateEvent { - New(Candidate), - Invalid(Candidate), -} - struct InitialConnection { agent: IceAgent, session_key: Secret<[u8; 32]>,