From 046b9e0cd4ed89abb194efa8b8ea630ee926380c Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 29 Oct 2024 03:57:41 +1100 Subject: [PATCH] refactor(connlib): track srvflx candidates separately (#7163) As part of maintaining an allocation, we also perform STUN with our relays to discover our server-reflexive address. At the moment, these candidates are scoped to an `Allocation`. This is unnecessarily restrictive. Similar to host candidates, server-reflexive candidates entirely depend on the socket you send data from and are thus independent of the allocation's state. During normal operation, this doesn't really matter because all relay traffic is sent through the same sockets so all `Allocation`s end up with the same server-reflexive candidates. Where this does matter is when we disconnect from relays for one reason or another (for example: #7162). The fact that all but host-candidates are scoped to `Allocation`s means that without `Allocation`s, we cannot make any new connections, not even direct ones. This is unnecessarily restrictive and causes bugs within `Allocation` to have a bigger blast radius than necessary. With this PR, we keep server-reflexive candidates in the same set as host candidates. This allows us to at least establish direct connections in case something is wrong with the relays or our state tracking of relays on the client side. 
--- rust/connlib/snownet/src/allocation.rs | 66 +++++++++++------------ rust/connlib/snownet/src/candidate_set.rs | 28 ++++++++++ rust/connlib/snownet/src/lib.rs | 1 + rust/connlib/snownet/src/node.rs | 57 ++++++-------------- 4 files changed, 77 insertions(+), 75 deletions(-) create mode 100644 rust/connlib/snownet/src/candidate_set.rs diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 61ab58119..faaa4cfe1 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -1,6 +1,6 @@ use crate::{ backoff::{self, ExponentialBackoff}, - node::{CandidateEvent, SessionId, Transmit}, + node::{SessionId, Transmit}, ringbuffer::RingBuffer, utils::earliest, EncryptedPacket, @@ -70,7 +70,7 @@ pub struct Allocation { allocation_lifetime: Option<(Instant, Duration)>, buffered_transmits: VecDeque>, - events: VecDeque, + events: VecDeque, sent_requests: BTreeMap< TransactionId, @@ -91,6 +91,12 @@ pub struct Allocation { credentials: Option, } +#[derive(Debug, PartialEq)] +pub(crate) enum Event { + New(Candidate), + Invalid(Candidate), +} + #[derive(Debug, Clone)] struct Credentials { username: Username, @@ -227,15 +233,10 @@ impl Allocation { allocation } - pub fn current_candidates(&self) -> impl Iterator { - [ - self.ip4_srflx_candidate.clone(), - self.ip6_srflx_candidate.clone(), - self.ip4_allocation.clone(), - self.ip6_allocation.clone(), - ] - .into_iter() - .flatten() + pub fn current_relay_candidates(&self) -> impl Iterator { + [self.ip4_allocation.clone(), self.ip6_allocation.clone()] + .into_iter() + .flatten() } /// Refresh this allocation. 
@@ -654,7 +655,7 @@ impl Allocation { // TODO: Clean up unused channels } - pub fn poll_event(&mut self) -> Option { + pub fn poll_event(&mut self) -> Option { self.events.pop_front() } @@ -827,11 +828,11 @@ impl Allocation { tracing::info!(active_socket = ?self.active_socket, "Invalidating allocation"); if let Some(candidate) = self.ip4_allocation.take() { - self.events.push_back(CandidateEvent::Invalid(candidate)) + self.events.push_back(Event::Invalid(candidate)) } if let Some(candidate) = self.ip6_allocation.take() { - self.events.push_back(CandidateEvent::Invalid(candidate)) + self.events.push_back(Event::Invalid(candidate)) } self.channel_bindings.clear(); @@ -1047,17 +1048,17 @@ fn authenticate(message: Message, credentials: &Credentials) -> Messa fn update_candidate( maybe_new: Option, maybe_current: &mut Option, - events: &mut VecDeque, + events: &mut VecDeque, ) { match (maybe_new, &maybe_current) { (Some(new), Some(current)) if &new != current => { - events.push_back(CandidateEvent::New(new.clone())); - events.push_back(CandidateEvent::Invalid(current.clone())); + events.push_back(Event::New(new.clone())); + events.push_back(Event::Invalid(current.clone())); *maybe_current = Some(new); } (Some(new), None) => { *maybe_current = Some(new.clone()); - events.push_back(CandidateEvent::New(new)); + events.push_back(Event::New(new)); } _ => {} } @@ -1926,14 +1927,14 @@ mod tests { let next_event = allocation.poll_event(); assert_eq!( next_event, - Some(CandidateEvent::New( + Some(Event::New( Candidate::server_reflexive(PEER1, PEER1, Protocol::Udp).unwrap() )) ); let next_event = allocation.poll_event(); assert_eq!( next_event, - Some(CandidateEvent::New( + Some(Event::New( Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap() )) ); @@ -1978,21 +1979,20 @@ mod tests { assert_eq!( allocation.poll_event(), - Some(CandidateEvent::Invalid( + Some(Event::Invalid( Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap() )) ); assert_eq!( 
allocation.poll_event(), - Some(CandidateEvent::Invalid( + Some(Event::Invalid( Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap() )) ); assert!(allocation.poll_event().is_none()); assert_eq!( - allocation.current_candidates().collect::>(), - vec![Candidate::server_reflexive(PEER1, PEER1, Protocol::Udp).unwrap()], - "server-reflexive candidate should still be valid after refresh" + allocation.current_relay_candidates().collect::>(), + vec![], ) } @@ -2310,8 +2310,8 @@ mod tests { assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), ] ) } @@ -2330,8 +2330,8 @@ mod tests { assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), ] ) } @@ -2362,8 +2362,8 @@ mod tests { assert_eq!( iter::from_fn(|| allocation.poll_event()).collect::>(), vec![ - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), - CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()), + Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()), ] ); assert_eq!( @@ -2451,10 +2451,10 @@ mod tests { assert_eq!( events, vec![ - CandidateEvent::New( + Event::New( Candidate::server_reflexive(PEER2_IP4, PEER2_IP4, Protocol::Udp).unwrap() ), - 
CandidateEvent::New( + Event::New( Candidate::server_reflexive(PEER2_IP6, PEER2_IP6, Protocol::Udp).unwrap() ) ] diff --git a/rust/connlib/snownet/src/candidate_set.rs b/rust/connlib/snownet/src/candidate_set.rs new file mode 100644 index 000000000..ad07e2eb9 --- /dev/null +++ b/rust/connlib/snownet/src/candidate_set.rs @@ -0,0 +1,28 @@ +use std::collections::HashSet; + +use itertools::Itertools; +use str0m::Candidate; + +/// Custom "set" implementation for [`Candidate`]s based on a [`HashSet`] with an enforced ordering when iterating. +#[derive(Debug, Default)] +pub struct CandidateSet { + inner: HashSet, +} + +impl CandidateSet { + pub fn insert(&mut self, c: Candidate) -> bool { + self.inner.insert(c) + } + + pub fn clear(&mut self) { + self.inner.clear() + } + + #[expect( + clippy::disallowed_methods, + reason = "We are guaranteeing a stable ordering" + )] + pub fn iter(&self) -> impl Iterator { + self.inner.iter().sorted_by_key(|c| c.prio()) + } +} diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index a0aa1dc12..aeeb1abbd 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -2,6 +2,7 @@ mod allocation; mod backoff; +mod candidate_set; mod channel_data; mod index; mod node; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 7f1a2f51a..35fe00421 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -1,4 +1,5 @@ -use crate::allocation::{Allocation, RelaySocket, Socket}; +use crate::allocation::{self, Allocation, RelaySocket, Socket}; +use crate::candidate_set::CandidateSet; use crate::index::IndexLfsr; use crate::ringbuffer::RingBuffer; use crate::stats::{ConnectionStats, NodeStats}; @@ -13,7 +14,6 @@ use hex_display::HexDisplayExt; use ip_packet::{ ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, IpPacketBuf, MAX_DATAGRAM_PAYLOAD, }; -use itertools::Itertools as _; use rand::rngs::StdRng; use rand::seq::IteratorRandom; use 
rand::{random, Rng, SeedableRng}; @@ -114,7 +114,8 @@ pub struct Node { index: IndexLfsr, rate_limiter: Arc, - host_candidates: Vec, // `Candidate` doesn't implement `PartialOrd` so we cannot use a `BTreeSet`. Linear search is okay because we expect this vec to be <100 elements + /// Host and server-reflexive candidates that are shared between all connections. + shared_candidates: CandidateSet, buffered_transmits: VecDeque>, next_rate_limiter_reset: Option, @@ -168,7 +169,7 @@ where mode: T::new(), index: IndexLfsr::default(), rate_limiter: Arc::new(RateLimiter::new(public_key, HANDSHAKE_RATE_LIMIT)), - host_candidates: Default::default(), + shared_candidates: Default::default(), buffered_transmits: VecDeque::default(), next_rate_limiter_reset: None, pending_events: VecDeque::default(), @@ -205,7 +206,7 @@ where self.pending_events.extend(closed_connections); - self.host_candidates.clear(); + self.shared_candidates.clear(); self.connections.clear(); self.buffered_transmits.clear(); @@ -706,9 +707,7 @@ where agent.handle_timeout(now); if self.allocations.is_empty() { - tracing::warn!( - "No TURN servers connected; connection will very likely fail to establish" - ); + tracing::warn!("No TURN servers connected; connection may fail to establish"); } Connection { @@ -744,12 +743,10 @@ where fn add_local_as_host_candidate(&mut self, local: SocketAddr) -> Result<(), Error> { let host_candidate = Candidate::host(local, Protocol::Udp)?; - if self.host_candidates.contains(&host_candidate) { + if self.shared_candidates.insert(host_candidate.clone()) { return Ok(()); } - self.host_candidates.push(host_candidate.clone()); - for (cid, agent, _span) in self.connections.agents_mut() { add_local_candidate(cid, agent, host_candidate.clone(), &mut self.pending_events); } @@ -907,20 +904,18 @@ where tracing::trace!(%rid, ?event); match event { - CandidateEvent::New(candidate) + allocation::Event::New(candidate) if candidate.kind() == CandidateKind::ServerReflexive => { - for (cid, 
agent, _span) in self.connections.agents_mut() { - add_local_candidate(cid, agent, candidate.clone(), &mut self.pending_events) - } + self.shared_candidates.insert(candidate); } - CandidateEvent::New(candidate) => { + allocation::Event::New(candidate) => { for (cid, agent, _span) in self.connections.connecting_agents_by_relay_mut(rid) { add_local_candidate(cid, agent, candidate.clone(), &mut self.pending_events) } } - CandidateEvent::Invalid(candidate) => { + allocation::Event::Invalid(candidate) => { for (cid, agent, _span) in self.connections.agents_mut() { remove_local_candidate(cid, agent, &candidate, &mut self.pending_events); } @@ -1114,17 +1109,7 @@ where selected_relay: Option, agent: &mut IceAgent, ) { - for candidate in self.host_candidates.iter().cloned() { - add_local_candidate(connection, agent, candidate, &mut self.pending_events); - } - - for candidate in self - .allocations - .values() - .flat_map(|a| a.current_candidates()) - .filter(|c| c.kind() == CandidateKind::ServerReflexive) - .unique() - { + for candidate in self.shared_candidates.iter().cloned() { add_local_candidate(connection, agent, candidate, &mut self.pending_events); } @@ -1138,10 +1123,7 @@ where return; }; - for candidate in allocation - .current_candidates() - .filter(|c| c.kind() == CandidateKind::Relayed) - { + for candidate in allocation.current_relay_candidates() { add_local_candidate(connection, agent, candidate, &mut self.pending_events); } } @@ -1406,10 +1388,7 @@ fn invalidate_allocation_candidates( RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, { for (cid, agent, _guard) in connections.agents_mut() { - for candidate in allocation - .current_candidates() - .filter(|c| c.kind() == CandidateKind::Relayed) - { + for candidate in allocation.current_relay_candidates() { remove_local_candidate(cid, agent, &candidate, pending_events); } } @@ -1576,12 +1555,6 @@ impl<'a> Transmit<'a> { } } -#[derive(Debug, PartialEq)] -pub(crate) enum CandidateEvent { - 
New(Candidate), - Invalid(Candidate), -} - struct InitialConnection { agent: IceAgent, session_key: Secret<[u8; 32]>,