diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 98919a066..f12af732d 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -27,7 +27,6 @@ use crate::utils::earliest; use crate::{IpPacket, MutableIpPacket}; use boringtun::noise::errors::WireGuardError; use std::borrow::Cow; -use std::iter; use std::ops::ControlFlow; use stun_codec::rfc5389::attributes::{Realm, Username}; use tracing::{field, info_span, Span}; @@ -212,7 +211,7 @@ where CandidateKind::Relayed => { // Optimisatically try to bind the channel only on the same relay as the remote peer. - if let Some(allocation) = self.same_relay_as_peer(id, &candidate) { + if let Some(allocation) = self.same_relay_as_peer(&candidate) { allocation.bind_channel(candidate.addr(), now); return; } @@ -221,11 +220,7 @@ where } // In other cases, bind on all relays. - for relay in self.connections.allowed_turn_servers(&id) { - let Some(allocation) = self.allocations.get_mut(relay) else { - continue; - }; - + for allocation in self.allocations.values_mut() { allocation.bind_channel(candidate.addr(), now); } } @@ -235,20 +230,13 @@ where /// To do that, we need to check all candidates of each allocation and compare their IP. /// The same relay might be reachable over IPv4 and IPv6. #[must_use] - fn same_relay_as_peer(&mut self, id: TId, candidate: &Candidate) -> Option<&mut Allocation> { - self.allocations - .iter_mut() - .filter(|(relay, _)| { - self.connections - .allowed_turn_servers(&id) - .any(|allowed| allowed == *relay) - }) - .find_map(|(_, allocation)| { - allocation - .current_candidates() - .any(|c| c.addr().ip() == candidate.addr().ip()) - .then_some(allocation) - }) + fn same_relay_as_peer(&mut self, candidate: &Candidate) -> Option<&mut Allocation> { + self.allocations.iter_mut().find_map(|(_, allocation)| { + allocation + .current_candidates() + .any(|c| c.addr().ip() == candidate.addr().ip()) + .then_some(allocation) + }) } /// Decapsulate an incoming packet. @@ -459,8 +447,6 @@ where mut agent: IceAgent, remote: PublicKey, key: [u8; 32], - allowed_stun_servers: HashSet, - allowed_turn_servers: HashSet, intent_sent_at: Instant, now: Instant, ) -> Connection { @@ -481,8 +467,6 @@ where self.index.next(), Some(self.rate_limiter.clone()), ), - stun_servers: allowed_stun_servers, - turn_servers: allowed_turn_servers, next_timer_update: now, peer_socket: None, possible_sockets: Default::default(), @@ -663,21 +647,19 @@ where } fn bindings_and_allocations_drain_events(&mut self) { - let binding_events = self.bindings.iter_mut().flat_map(|(server, binding)| { - iter::from_fn(|| binding.poll_event().map(|e| (*server, e))) - }); + let binding_events = self + .bindings + .values_mut() + .flat_map(|binding| binding.poll_event()); let allocation_events = self .allocations - .iter_mut() - .flat_map(|(server, allocation)| { - iter::from_fn(|| allocation.poll_event().map(|e| (*server, e))) - }); + .values_mut() + .flat_map(|allocation| allocation.poll_event()); - for (server, event) in binding_events.chain(allocation_events) { + for event in binding_events.chain(allocation_events) { match event { CandidateEvent::New(candidate) => { add_local_candidate_to_all( - server, candidate, &mut self.connections, &mut self.pending_events, @@ -707,8 +689,8 @@ where pub fn new_connection( &mut self, id: TId, - allowed_stun_servers: HashSet, - allowed_turn_servers: HashSet<(SocketAddr, String, String, String)>, + stun_servers: HashSet, + turn_servers: HashSet<(SocketAddr, String, String, String)>, intent_sent_at: Instant, now: Instant, ) -> Offer { @@ -720,14 +702,8 @@ where tracing::info!("Replacing existing established connection"); }; - self.upsert_stun_servers(&allowed_stun_servers, now); - self.upsert_turn_servers(&allowed_turn_servers, now); - - let allowed_turn_servers = allowed_turn_servers - .iter() - .map(|(server, _, _, _)| server) - .copied() - .collect(); + self.upsert_stun_servers(&stun_servers, now); + self.upsert_turn_servers(&turn_servers, now); let mut agent = IceAgent::new(); agent.set_controlling(true); @@ -747,8 +723,6 @@ where let initial_connection = InitialConnection { agent, session_key, - stun_servers: allowed_stun_servers, - turn_servers: allowed_turn_servers, created_at: now, intent_sent_at, is_failed: false, @@ -782,19 +756,12 @@ where pass: answer.credentials.password, }); - self.seed_agent_with_local_candidates( - id, - &mut agent, - &initial.stun_servers, - &initial.turn_servers, - ); + self.seed_agent_with_local_candidates(id, &mut agent); let connection = self.init_connection( agent, remote, *initial.session_key.expose_secret(), - initial.stun_servers, - initial.turn_servers, initial.intent_sent_at, now, ); @@ -823,8 +790,8 @@ where id: TId, offer: Offer, remote: PublicKey, - allowed_stun_servers: HashSet, - allowed_turn_servers: HashSet<(SocketAddr, String, String, String)>, + stun_servers: HashSet, + turn_servers: HashSet<(SocketAddr, String, String, String)>, now: Instant, ) -> Answer { debug_assert!( @@ -836,14 +803,8 @@ where tracing::info!("Replacing existing established connection"); }; - self.upsert_stun_servers(&allowed_stun_servers, now); - self.upsert_turn_servers(&allowed_turn_servers, now); - - let allowed_turn_servers = allowed_turn_servers - .iter() - .map(|(server, _, _, _)| server) - .copied() - .collect(); + self.upsert_stun_servers(&stun_servers, now); + self.upsert_turn_servers(&turn_servers, now); let mut agent = IceAgent::new(); agent.set_controlling(false); @@ -858,19 +819,12 @@ where }, }; - self.seed_agent_with_local_candidates( - id, - &mut agent, - &allowed_stun_servers, - &allowed_turn_servers, - ); + self.seed_agent_with_local_candidates(id, &mut agent); let connection = self.init_connection( agent, remote, *offer.session_key.expose_secret(), - allowed_stun_servers, - allowed_turn_servers, now, // Technically, this isn't fully correct because gateways don't send intents so we just use the current time. now, ); @@ -928,24 +882,16 @@ where } } - fn seed_agent_with_local_candidates( - &mut self, - connection: TId, - agent: &mut IceAgent, - allowed_stun_servers: &HashSet, - allowed_turn_servers: &HashSet, - ) { + fn seed_agent_with_local_candidates(&mut self, connection: TId, agent: &mut IceAgent) { for candidate in self.host_candidates.iter().cloned() { add_local_candidate(connection, agent, candidate, &mut self.pending_events); } - for candidate in self.bindings.iter().filter_map(|(server, binding)| { - let candidate = allowed_stun_servers - .contains(server) - .then(|| binding.candidate())??; - - Some(candidate) - }) { + for candidate in self + .bindings + .values() + .filter_map(|binding| binding.candidate()) + { add_local_candidate( connection, agent, @@ -956,13 +902,8 @@ where for candidate in self .allocations - .iter() - .flat_map(|(server, allocation)| { - allowed_turn_servers - .contains(server) - .then(|| allocation.current_candidates()) - }) - .flatten() + .values() + .flat_map(|allocation| allocation.current_candidates()) { add_local_candidate( connection, @@ -1037,21 +978,6 @@ where self.established.get_mut(id) } - fn allowed_turn_servers(&self, id: &TId) -> impl Iterator + '_ { - let initial = self - .initial - .get(id) - .into_iter() - .flat_map(|c| c.turn_servers.iter()); - let established = self - .established - .get(id) - .into_iter() - .flat_map(|c| c.turn_servers.iter()); - - initial.chain(established) - } - fn iter_established(&self) -> impl Iterator { self.established.iter().map(|(id, conn)| (*id, conn)) } @@ -1098,7 +1024,6 @@ enum EncodeError { } fn add_local_candidate_to_all( - server: SocketAddr, candidate: Candidate, connections: &mut Connections, pending_events: &mut VecDeque>, @@ -1108,34 +1033,15 @@ fn add_local_candidate_to_all( let initial_connections = connections .initial .iter_mut() - .map(|(id, c)| (*id, &c.stun_servers, &c.turn_servers, &mut c.agent)); + .map(|(id, c)| (*id, &mut c.agent)); let established_connections = connections .established .iter_mut() - .map(|(id, c)| (*id, &c.stun_servers, &c.turn_servers, &mut c.agent)); + .map(|(id, c)| (*id, &mut c.agent)); - for (id, allowed_stun, allowed_turn, agent) in - initial_connections.chain(established_connections) - { + for (id, agent) in initial_connections.chain(established_connections) { let _span = info_span!("connection", %id).entered(); - match candidate.kind() { - CandidateKind::ServerReflexive => { - if (!allowed_stun.contains(&server)) && (!allowed_turn.contains(&server)) { - tracing::debug!(%server, ?allowed_stun, ?allowed_turn, "Not adding srflx candidate"); - continue; - } - } - CandidateKind::Relayed => { - if !allowed_turn.contains(&server) { - tracing::debug!(%server, ?allowed_turn, "Not adding relay candidate"); - - continue; - } - } - CandidateKind::PeerReflexive | CandidateKind::Host => continue, - } - add_local_candidate(id, agent, candidate.clone(), pending_events); } } @@ -1226,8 +1132,6 @@ pub(crate) enum CandidateEvent { struct InitialConnection { agent: IceAgent, session_key: Secret<[u8; 32]>, - stun_servers: HashSet, - turn_servers: HashSet, created_at: Instant, intent_sent_at: Instant, @@ -1265,9 +1169,6 @@ struct Connection { // Socket addresses from which we might receive data (even before we are connected). possible_sockets: HashSet, - stun_servers: HashSet, - turn_servers: HashSet, - stats: ConnectionStats, buffer: Box<[u8; MAX_UDP_SIZE]>,