diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index ea5cb8507..2c75f61e4 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -30,6 +30,7 @@ use std::borrow::Cow; use std::iter; use std::ops::ControlFlow; use stun_codec::rfc5389::attributes::{Realm, Username}; +use tracing::{field, Span}; // Note: Taken from boringtun const HANDSHAKE_RATE_LIMIT: u64 = 100; @@ -372,10 +373,7 @@ where .. } => { let candidate = conn - .agent - .local_candidates() - .iter() - .find(|c| c.addr() == source) + .local_candidate(source) .expect("to only nominate existing candidates"); let remote_socket = match candidate.kind() { @@ -415,6 +413,8 @@ where tracing::info!(old = ?conn.peer_socket, new = ?remote_socket, "Updating remote socket"); conn.peer_socket = Some(remote_socket); + conn.invalidate_candiates(); + if is_first_connection { tracing::info!(%id, "Starting wireguard handshake"); @@ -1292,6 +1292,15 @@ enum PeerSocket { }, } +impl PeerSocket { + fn our_socket(&self) -> SocketAddr { + match self { + PeerSocket::Direct { source, .. } => *source, + PeerSocket::Relay { relay, .. } => *relay, + } + } +} + impl Connection { /// Checks if we want to accept a packet from a certain address. /// @@ -1464,4 +1473,40 @@ impl Connection { self.encapsulate(bytes, allocations, now) } + + /// Invalidates all local candidates with a lower or equal priority compared to the nominated one. + /// + /// Each time we nominate a candidate pair, we don't really want to keep all the others active because it creates a lot of noise. + /// At the same time, we want to retain trickle ICE and allow the ICE agent to find a _better_ pair, hence we invalidate by priority. + #[tracing::instrument(level = "debug", skip_all, fields(nominated_prio))] + fn invalidate_candiates(&mut self) { + let Some(socket) = self.peer_socket else { + return; + }; + + let Some(nominated) = self.local_candidate(socket.our_socket()).cloned() else { + return; + }; + + Span::current().record("nominated_prio", field::display(&nominated.prio())); + + let irrelevant_candidates = self + .agent + .local_candidates() + .iter() + .filter(|c| c.prio() <= nominated.prio() && c != &&nominated) + .cloned() + .collect::>(); + + for candidate in irrelevant_candidates { + self.agent.invalidate_candidate(&candidate); + } + } + + fn local_candidate(&self, source: SocketAddr) -> Option<&Candidate> { + self.agent + .local_candidates() + .iter() + .find(|c| c.addr() == source) + } }