refactor(connlib): track srvflx candidates separately (#7163)

As part of maintaining an allocation, we also perform STUN with our
relays to discover our server-reflexive address. At the moment, these
candidates are scoped to an `Allocation`. This is unnecessarily
restrictive. Similar to host candidates, server-reflexive candidates
entirely depend on the socket you send data from and are thus
independent of the allocation's state.

During normal operation, this doesn't really matter because all relay
traffic is sent through the same sockets so all `Allocation`s end up
with the same server-reflexive candidates. Where this does matter is
when we disconnect from relays for one reason or another (for example:
#7162). The fact that all but host-candidates are scoped to
`Allocation`s means that without `Allocation`s, we cannot make any new
connections, not even direct ones. This is unnecessarily restrictive and
causes bugs within `Allocation` to have a bigger blast radius than
necessary.

With this PR, we keep server-reflexive candidates in the same set as
host candidates. This allows us to at least establish direct connections
in case something is wrong with the relays or our state tracking of
relays on the client side.
This commit is contained in:
Thomas Eizinger
2024-10-29 03:57:41 +11:00
committed by GitHub
parent 1f7a0430b7
commit 046b9e0cd4
4 changed files with 77 additions and 75 deletions

View File

@@ -1,6 +1,6 @@
use crate::{
backoff::{self, ExponentialBackoff},
node::{CandidateEvent, SessionId, Transmit},
node::{SessionId, Transmit},
ringbuffer::RingBuffer,
utils::earliest,
EncryptedPacket,
@@ -70,7 +70,7 @@ pub struct Allocation {
allocation_lifetime: Option<(Instant, Duration)>,
buffered_transmits: VecDeque<Transmit<'static>>,
events: VecDeque<CandidateEvent>,
events: VecDeque<Event>,
sent_requests: BTreeMap<
TransactionId,
@@ -91,6 +91,12 @@ pub struct Allocation {
credentials: Option<Credentials>,
}
/// Candidate-related events emitted by an `Allocation`.
#[derive(Debug, PartialEq)]
pub(crate) enum Event {
/// A new candidate was discovered and should be announced to connections.
New(Candidate),
/// A previously announced candidate is no longer valid and should be removed.
Invalid(Candidate),
}
#[derive(Debug, Clone)]
struct Credentials {
username: Username,
@@ -227,15 +233,10 @@ impl Allocation {
allocation
}
pub fn current_candidates(&self) -> impl Iterator<Item = Candidate> {
[
self.ip4_srflx_candidate.clone(),
self.ip6_srflx_candidate.clone(),
self.ip4_allocation.clone(),
self.ip6_allocation.clone(),
]
.into_iter()
.flatten()
pub fn current_relay_candidates(&self) -> impl Iterator<Item = Candidate> {
[self.ip4_allocation.clone(), self.ip6_allocation.clone()]
.into_iter()
.flatten()
}
/// Refresh this allocation.
@@ -654,7 +655,7 @@ impl Allocation {
// TODO: Clean up unused channels
}
pub fn poll_event(&mut self) -> Option<CandidateEvent> {
pub fn poll_event(&mut self) -> Option<Event> {
self.events.pop_front()
}
@@ -827,11 +828,11 @@ impl Allocation {
tracing::info!(active_socket = ?self.active_socket, "Invalidating allocation");
if let Some(candidate) = self.ip4_allocation.take() {
self.events.push_back(CandidateEvent::Invalid(candidate))
self.events.push_back(Event::Invalid(candidate))
}
if let Some(candidate) = self.ip6_allocation.take() {
self.events.push_back(CandidateEvent::Invalid(candidate))
self.events.push_back(Event::Invalid(candidate))
}
self.channel_bindings.clear();
@@ -1047,17 +1048,17 @@ fn authenticate(message: Message<Attribute>, credentials: &Credentials) -> Messa
fn update_candidate(
maybe_new: Option<Candidate>,
maybe_current: &mut Option<Candidate>,
events: &mut VecDeque<CandidateEvent>,
events: &mut VecDeque<Event>,
) {
match (maybe_new, &maybe_current) {
(Some(new), Some(current)) if &new != current => {
events.push_back(CandidateEvent::New(new.clone()));
events.push_back(CandidateEvent::Invalid(current.clone()));
events.push_back(Event::New(new.clone()));
events.push_back(Event::Invalid(current.clone()));
*maybe_current = Some(new);
}
(Some(new), None) => {
*maybe_current = Some(new.clone());
events.push_back(CandidateEvent::New(new));
events.push_back(Event::New(new));
}
_ => {}
}
@@ -1926,14 +1927,14 @@ mod tests {
let next_event = allocation.poll_event();
assert_eq!(
next_event,
Some(CandidateEvent::New(
Some(Event::New(
Candidate::server_reflexive(PEER1, PEER1, Protocol::Udp).unwrap()
))
);
let next_event = allocation.poll_event();
assert_eq!(
next_event,
Some(CandidateEvent::New(
Some(Event::New(
Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()
))
);
@@ -1978,21 +1979,20 @@ mod tests {
assert_eq!(
allocation.poll_event(),
Some(CandidateEvent::Invalid(
Some(Event::Invalid(
Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()
))
);
assert_eq!(
allocation.poll_event(),
Some(CandidateEvent::Invalid(
Some(Event::Invalid(
Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()
))
);
assert!(allocation.poll_event().is_none());
assert_eq!(
allocation.current_candidates().collect::<Vec<_>>(),
vec![Candidate::server_reflexive(PEER1, PEER1, Protocol::Udp).unwrap()],
"server-reflexive candidate should still be valid after refresh"
allocation.current_relay_candidates().collect::<Vec<_>>(),
vec![],
)
}
@@ -2310,8 +2310,8 @@ mod tests {
assert_eq!(
iter::from_fn(|| allocation.poll_event()).collect::<Vec<_>>(),
vec![
CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()),
CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()),
Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()),
Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()),
]
)
}
@@ -2330,8 +2330,8 @@ mod tests {
assert_eq!(
iter::from_fn(|| allocation.poll_event()).collect::<Vec<_>>(),
vec![
CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()),
CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()),
Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()),
Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()),
]
)
}
@@ -2362,8 +2362,8 @@ mod tests {
assert_eq!(
iter::from_fn(|| allocation.poll_event()).collect::<Vec<_>>(),
vec![
CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()),
CandidateEvent::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()),
Event::Invalid(Candidate::relayed(RELAY_ADDR_IP4, Protocol::Udp).unwrap()),
Event::Invalid(Candidate::relayed(RELAY_ADDR_IP6, Protocol::Udp).unwrap()),
]
);
assert_eq!(
@@ -2451,10 +2451,10 @@ mod tests {
assert_eq!(
events,
vec![
CandidateEvent::New(
Event::New(
Candidate::server_reflexive(PEER2_IP4, PEER2_IP4, Protocol::Udp).unwrap()
),
CandidateEvent::New(
Event::New(
Candidate::server_reflexive(PEER2_IP6, PEER2_IP6, Protocol::Udp).unwrap()
)
]

View File

@@ -0,0 +1,28 @@
use std::collections::HashSet;
use itertools::Itertools;
use str0m::Candidate;
/// Custom "set" implementation for [`Candidate`]s based on a [`HashSet`] with an enforced ordering when iterating.
///
/// `Candidate` doesn't implement `PartialOrd`, so a `BTreeSet` cannot be used;
/// instead, [`CandidateSet::iter`] sorts by priority on each call.
#[derive(Debug, Default)]
pub struct CandidateSet {
// De-duplicated storage; a deterministic iteration order is imposed in `iter`.
inner: HashSet<Candidate>,
}
impl CandidateSet {
    /// Adds a candidate to the set.
    ///
    /// Returns `true` if the candidate was not already present.
    pub fn insert(&mut self, c: Candidate) -> bool {
        self.inner.insert(c)
    }

    /// Removes all candidates from the set.
    pub fn clear(&mut self) {
        self.inner.clear()
    }

    /// Iterates over all candidates in a stable order (ascending priority).
    #[expect(
        clippy::disallowed_methods,
        reason = "We are guaranteeing a stable ordering"
    )]
    pub fn iter(&self) -> impl Iterator<Item = &Candidate> {
        let mut ordered: Vec<&Candidate> = self.inner.iter().collect();
        ordered.sort_by_key(|candidate| candidate.prio());
        ordered.into_iter()
    }
}

View File

@@ -2,6 +2,7 @@
mod allocation;
mod backoff;
mod candidate_set;
mod channel_data;
mod index;
mod node;

View File

@@ -1,4 +1,5 @@
use crate::allocation::{Allocation, RelaySocket, Socket};
use crate::allocation::{self, Allocation, RelaySocket, Socket};
use crate::candidate_set::CandidateSet;
use crate::index::IndexLfsr;
use crate::ringbuffer::RingBuffer;
use crate::stats::{ConnectionStats, NodeStats};
@@ -13,7 +14,6 @@ use hex_display::HexDisplayExt;
use ip_packet::{
ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, IpPacketBuf, MAX_DATAGRAM_PAYLOAD,
};
use itertools::Itertools as _;
use rand::rngs::StdRng;
use rand::seq::IteratorRandom;
use rand::{random, Rng, SeedableRng};
@@ -114,7 +114,8 @@ pub struct Node<T, TId, RId> {
index: IndexLfsr,
rate_limiter: Arc<RateLimiter>,
host_candidates: Vec<Candidate>, // `Candidate` doesn't implement `PartialOrd` so we cannot use a `BTreeSet`. Linear search is okay because we expect this vec to be <100 elements
/// Host and server-reflexive candidates that are shared between all connections.
shared_candidates: CandidateSet,
buffered_transmits: VecDeque<Transmit<'static>>,
next_rate_limiter_reset: Option<Instant>,
@@ -168,7 +169,7 @@ where
mode: T::new(),
index: IndexLfsr::default(),
rate_limiter: Arc::new(RateLimiter::new(public_key, HANDSHAKE_RATE_LIMIT)),
host_candidates: Default::default(),
shared_candidates: Default::default(),
buffered_transmits: VecDeque::default(),
next_rate_limiter_reset: None,
pending_events: VecDeque::default(),
@@ -205,7 +206,7 @@ where
self.pending_events.extend(closed_connections);
self.host_candidates.clear();
self.shared_candidates.clear();
self.connections.clear();
self.buffered_transmits.clear();
@@ -706,9 +707,7 @@ where
agent.handle_timeout(now);
if self.allocations.is_empty() {
tracing::warn!(
"No TURN servers connected; connection will very likely fail to establish"
);
tracing::warn!("No TURN servers connected; connection may fail to establish");
}
Connection {
@@ -744,12 +743,10 @@ where
fn add_local_as_host_candidate(&mut self, local: SocketAddr) -> Result<(), Error> {
let host_candidate = Candidate::host(local, Protocol::Udp)?;
if self.host_candidates.contains(&host_candidate) {
if self.shared_candidates.insert(host_candidate.clone()) {
return Ok(());
}
self.host_candidates.push(host_candidate.clone());
for (cid, agent, _span) in self.connections.agents_mut() {
add_local_candidate(cid, agent, host_candidate.clone(), &mut self.pending_events);
}
@@ -907,20 +904,18 @@ where
tracing::trace!(%rid, ?event);
match event {
CandidateEvent::New(candidate)
allocation::Event::New(candidate)
if candidate.kind() == CandidateKind::ServerReflexive =>
{
for (cid, agent, _span) in self.connections.agents_mut() {
add_local_candidate(cid, agent, candidate.clone(), &mut self.pending_events)
}
self.shared_candidates.insert(candidate);
}
CandidateEvent::New(candidate) => {
allocation::Event::New(candidate) => {
for (cid, agent, _span) in self.connections.connecting_agents_by_relay_mut(rid)
{
add_local_candidate(cid, agent, candidate.clone(), &mut self.pending_events)
}
}
CandidateEvent::Invalid(candidate) => {
allocation::Event::Invalid(candidate) => {
for (cid, agent, _span) in self.connections.agents_mut() {
remove_local_candidate(cid, agent, &candidate, &mut self.pending_events);
}
@@ -1114,17 +1109,7 @@ where
selected_relay: Option<RId>,
agent: &mut IceAgent,
) {
for candidate in self.host_candidates.iter().cloned() {
add_local_candidate(connection, agent, candidate, &mut self.pending_events);
}
for candidate in self
.allocations
.values()
.flat_map(|a| a.current_candidates())
.filter(|c| c.kind() == CandidateKind::ServerReflexive)
.unique()
{
for candidate in self.shared_candidates.iter().cloned() {
add_local_candidate(connection, agent, candidate, &mut self.pending_events);
}
@@ -1138,10 +1123,7 @@ where
return;
};
for candidate in allocation
.current_candidates()
.filter(|c| c.kind() == CandidateKind::Relayed)
{
for candidate in allocation.current_relay_candidates() {
add_local_candidate(connection, agent, candidate, &mut self.pending_events);
}
}
@@ -1406,10 +1388,7 @@ fn invalidate_allocation_candidates<TId, RId>(
RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display,
{
for (cid, agent, _guard) in connections.agents_mut() {
for candidate in allocation
.current_candidates()
.filter(|c| c.kind() == CandidateKind::Relayed)
{
for candidate in allocation.current_relay_candidates() {
remove_local_candidate(cid, agent, &candidate, pending_events);
}
}
@@ -1576,12 +1555,6 @@ impl<'a> Transmit<'a> {
}
}
#[derive(Debug, PartialEq)]
pub(crate) enum CandidateEvent {
New(Candidate),
Invalid(Candidate),
}
struct InitialConnection<RId> {
agent: IceAgent,
session_key: Secret<[u8; 32]>,