diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index b52dd431b..000b80d77 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -1,9 +1,11 @@ +mod allocations; mod connections; pub use connections::UnknownConnection; use crate::allocation::{self, Allocation, RelaySocket, Socket}; use crate::index::IndexLfsr; +use crate::node::allocations::Allocations; use crate::node::connections::Connections; use crate::stats::{ConnectionStats, NodeStats}; use crate::utils::channel_data_packet_buffer; @@ -20,12 +22,10 @@ use hex_display::HexDisplayExt; use ip_packet::{Ecn, IpPacket, IpPacketBuf}; use itertools::Itertools; use rand::rngs::StdRng; -use rand::seq::IteratorRandom; use rand::{RngCore, SeedableRng}; use ringbuffer::{AllocRingBuffer, RingBuffer as _}; use sha2::Digest; -use std::collections::btree_map::Entry; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use std::hash::Hash; use std::net::IpAddr; use std::ops::ControlFlow; @@ -130,7 +130,7 @@ pub struct Node { next_rate_limiter_reset: Option, - allocations: BTreeMap, + allocations: Allocations, connections: Connections, pending_events: VecDeque>, @@ -426,7 +426,7 @@ where | CandidateKind::PeerReflexive => {} } - let Some(allocation) = self.allocations.get_mut(&relay) else { + let Some(allocation) = self.allocations.get_mut_by_id(&relay) else { tracing::debug!(rid = %relay, "Unknown relay"); return; }; @@ -547,11 +547,7 @@ where pub fn poll_timeout(&mut self) -> Option<(Instant, &'static str)> { iter::empty() .chain(self.connections.poll_timeout()) - .chain( - self.allocations - .values_mut() - .filter_map(|a| a.poll_timeout()), - ) + .chain(self.allocations.poll_timeout()) .chain( self.next_rate_limiter_reset .map(|instant| (instant, "rate limiter reset")), @@ -576,9 +572,7 @@ where /// /// As such, it ends up being cleaner to "drain" all lower-level components of their events, transmits etc within this function. pub fn handle_timeout(&mut self, now: Instant) { - for allocation in self.allocations.values_mut() { - allocation.handle_timeout(now); - } + self.allocations.handle_timeout(now); self.allocations_drain_events(); @@ -602,15 +596,7 @@ where } } - self.allocations - .retain(|rid, allocation| match allocation.can_be_freed() { - Some(e) => { - tracing::info!(%rid, "Disconnecting from relay; {e}"); - - false - } - None => true, - }); + self.allocations.gc(); self.connections.check_relays_available( &self.allocations, &mut self.pending_events, @@ -623,12 +609,7 @@ where /// Returns buffered data that needs to be sent on the socket. #[must_use] pub fn poll_transmit(&mut self) -> Option { - let allocation_transmits = &mut self - .allocations - .values_mut() - .flat_map(Allocation::poll_transmit); - - if let Some(transmit) = allocation_transmits.next() { + if let Some(transmit) = self.allocations.poll_transmit() { self.stats.stun_bytes_to_relays += transmit.payload.len(); tracing::trace!(?transmit); @@ -650,7 +631,7 @@ where ) { // First, invalidate all candidates from relays that we should stop using. for rid in &to_remove { - let Some(allocation) = self.allocations.remove(rid) else { + let Some(allocation) = self.allocations.remove_by_id(rid) else { tracing::debug!(%rid, "Cannot delete unknown allocation"); continue; @@ -676,47 +657,29 @@ where continue; }; - match self.allocations.entry(*rid) { - Entry::Vacant(v) => { - v.insert(Allocation::new( - *server, - username, - password.clone(), - realm, - now, - self.session_id.clone(), - self.buffer_pool.clone(), - )); - - tracing::info!(%rid, address = ?server, "Added new TURN server"); + match self.allocations.upsert( + *rid, + *server, + username, + password.clone(), + realm, + now, + self.session_id.clone(), + ) { + allocations::UpsertResult::Added => { + tracing::info!(%rid, address = ?server, "Added new TURN server") } - Entry::Occupied(mut o) => { - let allocation = o.get(); - - if allocation.matches_credentials(&username, password) - && allocation.matches_socket(server) - { - tracing::info!(%rid, address = ?server, "Skipping known TURN server"); - continue; - } - + allocations::UpsertResult::Skipped => { + tracing::info!(%rid, address = ?server, "Skipping known TURN server") + } + allocations::UpsertResult::Replaced(previous) => { invalidate_allocation_candidates( &mut self.connections, - allocation, + &previous, &mut self.pending_events, ); - o.insert(Allocation::new( - *server, - username, - password.clone(), - realm, - now, - self.session_id.clone(), - self.buffer_pool.clone(), - )); - - tracing::info!(%rid, address = ?server, "Replaced TURN server"); + tracing::info!(%rid, address = ?server, "Replaced TURN server") } } } @@ -829,7 +792,12 @@ where // The above check would wrongly classify a STUN request from such a peer as relay traffic and // fail to process it because we don't have an `Allocation` for the peer's IP. // - // Effectively this means that the connection will have to fallback to a relay-relay candidate pair. + // At the same time, we may still receive packets on port 3478 for an allocation that we have discarded. + // + // To correctly handle these packets, we need to handle them differently, depending on whether we + // previously had an allocation on a certain relay: + // 1. If we previously had an allocation, we need to stop processing the packet. + // 2. If we don't recognize the IP, continue processing the packet (as it may be p2p traffic). return ControlFlow::Continue((from, packet, None)); } @@ -841,18 +809,20 @@ where return ControlFlow::Continue((from, packet, None)); }; - let Some(allocation) = self - .allocations - .values_mut() - .find(|a| a.server().matches(from)) - else { - tracing::debug!( - %from, - packet = %hex::encode(packet), - "Packet was a STUN message but we are not connected to this relay" - ); + let allocation = match self.allocations.get_mut_by_server(from) { + allocations::MutAllocationRef::Connected(_, allocation) => allocation, + allocations::MutAllocationRef::Disconnected => { + tracing::debug!( + %from, + packet = %hex::encode(packet), + "Packet was a STUN message but we are no longer connected to this relay" + ); - return ControlFlow::Break(()); // Stop processing the packet. + return ControlFlow::Break(()); // Stop processing the packet. + } + allocations::MutAllocationRef::Unknown => { + return ControlFlow::Continue((from, packet, None)); + } }; if allocation.handle_input(from, local, message, now) { @@ -871,14 +841,19 @@ where return ControlFlow::Continue((from, packet, None)); }; - let Some(allocation) = self - .allocations - .values_mut() - .find(|a| a.server().matches(from)) - else { - tracing::debug!("Packet was a channel data message for unknown allocation"); + let allocation = match self.allocations.get_mut_by_server(from) { + allocations::MutAllocationRef::Connected(_, allocation) => allocation, + allocations::MutAllocationRef::Disconnected => { + tracing::debug!( + %from, + "Packet was a channel-data message but we are no longer connected to this relay" + ); - return ControlFlow::Break(()); // Stop processing the packet. + return ControlFlow::Break(()); // Stop processing the packet. + } + allocations::MutAllocationRef::Unknown => { + return ControlFlow::Continue((from, packet, None)); + } }; let Some((from, packet, socket)) = allocation.decapsulate(from, cd, now) else { @@ -1005,11 +980,7 @@ where } fn allocations_drain_events(&mut self) { - let allocation_events = self.allocations.iter_mut().flat_map(|(rid, allocation)| { - std::iter::from_fn(|| allocation.poll_event()).map(|e| (*rid, e)) - }); - - for (rid, event) in allocation_events { + while let Some((rid, event)) = self.allocations.poll_event() { tracing::trace!(%rid, ?event); match event { @@ -1029,11 +1000,9 @@ where /// Sample a relay to use for a new connection. fn sample_relay(&mut self) -> Result { - let rid = self + let (rid, _) = self .allocations - .keys() - .copied() - .choose(&mut self.rng) + .sample(&mut self.rng) .ok_or(NoTurnServers {})?; tracing::debug!(%rid, "Sampled relay"); @@ -1225,18 +1194,15 @@ fn seed_agent_with_local_candidates( connection: TId, selected_relay: RId, agent: &mut IceAgent, - allocations: &BTreeMap, + allocations: &Allocations, pending_events: &mut VecDeque>, ) where - RId: Ord, + RId: Ord + fmt::Display + Copy, TId: fmt::Display + Copy, { - let shared_candidates = allocations - .values() - .flat_map(|allocation| allocation.host_and_server_reflexive_candidates()) - .unique(); + let shared_candidates = allocations.shared_candidates(); let relay_candidates = allocations - .get(&selected_relay) + .get_by_id(&selected_relay) .into_iter() .flat_map(|allocation| allocation.current_relay_candidates()); @@ -1846,7 +1812,7 @@ where &mut self, cid: TId, now: Instant, - allocations: &mut BTreeMap, + allocations: &mut Allocations, transmits: &mut VecDeque, ) where TId: Copy + Ord + fmt::Display, @@ -1913,9 +1879,7 @@ where source, .. } => { - let source_relay = allocations.iter().find_map(|(relay, allocation)| { - allocation.has_socket(source).then_some(*relay) - }); + let source_relay = allocations.get_mut_by_allocation(source).map(|(r, _)| r); if source_relay.is_some_and(|r| self.relay.id != r) { tracing::warn!( @@ -2042,11 +2006,7 @@ where let stun_packet = transmit.contents; // Check if `str0m` wants us to send from a "remote" socket, i.e. one that we allocated with a relay. - let allocation = allocations - .iter_mut() - .find(|(_, allocation)| allocation.has_socket(source)); - - let Some((relay, allocation)) = allocation else { + let Some((relay, allocation)) = allocations.get_mut_by_allocation(source) else { self.stats.stun_bytes_to_peer_direct += stun_packet.len(); // `source` did not match any of our allocated sockets, must be a local one then! @@ -2084,7 +2044,7 @@ where fn handle_tunnel_timeout( &mut self, now: Instant, - allocations: &mut BTreeMap, + allocations: &mut Allocations, transmits: &mut VecDeque, ) { // Don't update wireguard timers until we are connected. @@ -2130,7 +2090,7 @@ where socket: PeerSocket, packet: &IpPacket, now: Instant, - allocations: &mut BTreeMap, + allocations: &mut Allocations, ) -> Result> where TId: fmt::Display, @@ -2173,7 +2133,7 @@ where ecn: packet.ecn(), })), PeerSocket::RelayToPeer { dest: peer } | PeerSocket::RelayToRelay { dest: peer } => { - let Some(allocation) = allocations.get_mut(&self.relay.id) else { + let Some(allocation) = allocations.get_mut_by_id(&self.relay.id) else { tracing::warn!(relay = %self.relay.id, "No allocation"); return Ok(None); }; @@ -2200,7 +2160,7 @@ where cid: TId, src: IpAddr, packet: &[u8], - allocations: &mut BTreeMap, + allocations: &mut Allocations, transmits: &mut VecDeque, now: Instant, ) -> ControlFlow, IpPacket> @@ -2307,7 +2267,7 @@ where fn initiate_wg_session( &mut self, - allocations: &mut BTreeMap, + allocations: &mut Allocations, transmits: &mut VecDeque, now: Instant, ) where @@ -2382,11 +2342,11 @@ fn make_owned_transmit( socket: PeerSocket, message: &[u8], buffer_pool: &BufferPool>, - allocations: &mut BTreeMap, + allocations: &mut Allocations, now: Instant, ) -> Option where - RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug, + RId: Ord + fmt::Display + Copy, { let transmit = match socket { PeerSocket::PeerToPeer { @@ -2403,7 +2363,7 @@ where ecn: Ecn::NonEct, }, PeerSocket::RelayToPeer { dest: peer } | PeerSocket::RelayToRelay { dest: peer } => { - let allocation = allocations.get_mut(&relay)?; + let allocation = allocations.get_mut_by_id(&relay)?; let mut channel_data = channel_data_packet_buffer(message); let encode_ok = allocation.encode_channel_data_header(peer, &mut channel_data, now)?; diff --git a/rust/connlib/snownet/src/node/allocations.rs b/rust/connlib/snownet/src/node/allocations.rs new file mode 100644 index 000000000..0704f39e6 --- /dev/null +++ b/rust/connlib/snownet/src/node/allocations.rs @@ -0,0 +1,320 @@ +use std::{ + collections::{BTreeMap, btree_map::Entry}, + fmt, + net::{IpAddr, SocketAddr}, + time::Instant, +}; + +use bufferpool::BufferPool; +use itertools::Itertools as _; +use rand::{Rng, seq::IteratorRandom as _}; +use ringbuffer::{AllocRingBuffer, RingBuffer}; +use str0m::Candidate; +use stun_codec::rfc5389::attributes::{Realm, Username}; + +use crate::{ + RelaySocket, Transmit, + allocation::{self, Allocation}, + node::SessionId, +}; + +pub(crate) struct Allocations { + inner: BTreeMap, + previous_relays_by_ip: AllocRingBuffer, + + buffer_pool: BufferPool>, +} + +impl Allocations +where + RId: Ord + fmt::Display + Copy, +{ + pub(crate) fn clear(&mut self) { + for (_, allocation) in std::mem::take(&mut self.inner) { + self.previous_relays_by_ip + .extend(server_addresses(&allocation)); + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub(crate) fn contains(&self, id: &RId) -> bool { + self.inner.contains_key(id) + } + + pub(crate) fn get_by_id(&self, id: &RId) -> Option<&Allocation> { + self.inner.get(id) + } + + pub(crate) fn get_mut_by_id(&mut self, id: &RId) -> Option<&mut Allocation> { + self.inner.get_mut(id) + } + + pub(crate) fn get_mut_by_allocation( + &mut self, + addr: SocketAddr, + ) -> Option<(RId, &mut Allocation)> { + self.inner + .iter_mut() + .find_map(|(id, a)| a.has_socket(addr).then_some((*id, a))) + } + + pub(crate) fn get_mut_by_server(&mut self, socket: SocketAddr) -> MutAllocationRef<'_, RId> { + self.inner + .iter_mut() + .find_map(|(id, a)| a.server().matches(socket).then_some((*id, a))) + .map(|(id, a)| MutAllocationRef::Connected(id, a)) + .or_else(|| { + self.previous_relays_by_ip + .contains(&socket.ip()) + .then_some(MutAllocationRef::Disconnected) + }) + .unwrap_or(MutAllocationRef::Unknown) + } + + pub(crate) fn iter_mut(&mut self) -> impl Iterator { + self.inner.iter_mut() + } + + pub(crate) fn remove_by_id(&mut self, id: &RId) -> Option { + let allocation = self.inner.remove(id)?; + + self.previous_relays_by_ip + .extend(server_addresses(&allocation)); + + Some(allocation) + } + + pub(crate) fn upsert( + &mut self, + rid: RId, + server: RelaySocket, + username: Username, + password: String, + realm: Realm, + now: Instant, + session_id: SessionId, + ) -> UpsertResult { + match self.inner.entry(rid) { + Entry::Vacant(v) => { + v.insert(Allocation::new( + server, + username, + password, + realm, + now, + session_id, + self.buffer_pool.clone(), + )); + + UpsertResult::Added + } + Entry::Occupied(mut o) => { + let allocation = o.get(); + + if allocation.matches_credentials(&username, &password) + && allocation.matches_socket(&server) + { + return UpsertResult::Skipped; + } + + let previous = o.insert(Allocation::new( + server, + username, + password, + realm, + now, + session_id, + self.buffer_pool.clone(), + )); + + self.previous_relays_by_ip + .extend(server_addresses(&previous)); + + UpsertResult::Replaced(previous) + } + } + } + + pub(crate) fn sample(&self, rng: &mut impl Rng) -> Option<(RId, &Allocation)> { + let (id, a) = self.inner.iter().choose(rng)?; + + Some((*id, a)) + } + + pub(crate) fn shared_candidates(&self) -> impl Iterator { + self.inner + .values() + .flat_map(|allocation| allocation.host_and_server_reflexive_candidates()) + .unique() + } + + pub(crate) fn poll_timeout(&mut self) -> Option<(Instant, &'static str)> { + self.inner + .values_mut() + .filter_map(|a| a.poll_timeout()) + .min_by_key(|(t, _)| *t) + } + + pub(crate) fn poll_event(&mut self) -> Option<(RId, allocation::Event)> { + self.inner + .iter_mut() + .filter_map(|(id, a)| Some((*id, a.poll_event()?))) + .next() + } + + pub(crate) fn handle_timeout(&mut self, now: Instant) { + for allocation in self.inner.values_mut() { + allocation.handle_timeout(now); + } + } + + pub(crate) fn poll_transmit(&mut self) -> Option { + self.inner + .values_mut() + .filter_map(Allocation::poll_transmit) + .next() + } + + pub(crate) fn gc(&mut self) { + self.inner + .retain(|rid, allocation| match allocation.can_be_freed() { + Some(e) => { + tracing::info!(%rid, "Disconnecting from relay; {e}"); + + self.previous_relays_by_ip + .extend(server_addresses(allocation)); + + false + } + None => true, + }); + } +} + +pub(crate) enum MutAllocationRef<'a, RId> { + Unknown, + Disconnected, + Connected(RId, &'a mut Allocation), +} + +fn server_addresses(allocation: &Allocation) -> impl Iterator { + std::iter::empty() + .chain( + allocation + .server() + .as_v4() + .map(|s| s.ip()) + .copied() + .map(IpAddr::from), + ) + .chain( + allocation + .server() + .as_v6() + .map(|s| s.ip()) + .copied() + .map(IpAddr::from), + ) +} + +pub(crate) enum UpsertResult { + Added, + Skipped, + Replaced(Allocation), +} + +impl Default for Allocations { + fn default() -> Self { + Self { + inner: Default::default(), + previous_relays_by_ip: AllocRingBuffer::with_capacity_power_of_2(6), // 64 entries, + buffer_pool: BufferPool::new(ip_packet::MAX_FZ_PAYLOAD, "turn-clients"), + } + } +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, SocketAddrV4}; + + use boringtun::x25519::PublicKey; + + use super::*; + + #[test] + fn manual_remove_remembers_address() { + let mut allocations = Allocations::default(); + allocations.upsert( + 1, + RelaySocket::from(SERVER_V4), + Username::new("test".to_owned()).unwrap(), + "password".to_owned(), + Realm::new("firezone".to_owned()).unwrap(), + Instant::now(), + SessionId::new(PublicKey::from([0u8; 32])), + ); + + allocations.remove_by_id(&1); + + assert!(matches!( + allocations.get_mut_by_server(SERVER_V4), + MutAllocationRef::Disconnected + )); + } + + #[test] + fn clear_remembers_address() { + let mut allocations = Allocations::default(); + allocations.upsert( + 1, + RelaySocket::from(SERVER_V4), + Username::new("test".to_owned()).unwrap(), + "password".to_owned(), + Realm::new("firezone".to_owned()).unwrap(), + Instant::now(), + SessionId::new(PublicKey::from([0u8; 32])), + ); + + allocations.clear(); + + assert!(matches!( + allocations.get_mut_by_server(SERVER_V4), + MutAllocationRef::Disconnected + )); + } + + #[test] + fn replace_by_address_remembers_address() { + let mut allocations = Allocations::default(); + allocations.upsert( + 1, + RelaySocket::from(SERVER_V4), + Username::new("test".to_owned()).unwrap(), + "password".to_owned(), + Realm::new("firezone".to_owned()).unwrap(), + Instant::now(), + SessionId::new(PublicKey::from([0u8; 32])), + ); + + allocations.upsert( + 1, + RelaySocket::from(SERVER2_V4), + Username::new("test".to_owned()).unwrap(), + "password".to_owned(), + Realm::new("firezone".to_owned()).unwrap(), + Instant::now(), + SessionId::new(PublicKey::from([0u8; 32])), + ); + + assert!(matches!( + allocations.get_mut_by_server(SERVER_V4), + MutAllocationRef::Disconnected + )); + } + + const SERVER_V4: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 11111)); + const SERVER2_V4: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 22222)); +} diff --git a/rust/connlib/snownet/src/node/connections.rs b/rust/connlib/snownet/src/node/connections.rs index a36d4ac14..c31d82a0b 100644 --- a/rust/connlib/snownet/src/node/connections.rs +++ b/rust/connlib/snownet/src/node/connections.rs @@ -8,13 +8,12 @@ use std::{ use anyhow::{Context as _, Result}; use boringtun::noise::Index; -use rand::{Rng, seq::IteratorRandom as _}; +use rand::Rng; use str0m::ice::IceAgent; use crate::{ ConnectionStats, Event, - allocation::Allocation, - node::{Connection, InitialConnection, add_local_candidate}, + node::{Connection, InitialConnection, add_local_candidate, allocations::Allocations}, }; pub struct Connections { @@ -100,16 +99,16 @@ where pub(crate) fn check_relays_available( &mut self, - allocations: &BTreeMap, + allocations: &Allocations, pending_events: &mut VecDeque>, rng: &mut impl Rng, ) { for (_, c) in self.iter_initial_mut() { - if allocations.contains_key(&c.relay) { + if allocations.contains(&c.relay) { continue; } - let Some(new_rid) = allocations.keys().copied().choose(rng) else { + let Some((new_rid, _)) = allocations.sample(rng) else { continue; }; @@ -118,11 +117,11 @@ where } for (cid, c) in self.iter_established_mut() { - if allocations.contains_key(&c.relay.id) { + if allocations.contains(&c.relay.id) { continue; // Our relay is still there, no problems. } - let Some((rid, allocation)) = allocations.iter().choose(rng) else { + let Some((rid, allocation)) = allocations.sample(rng) else { if !c.relay.logged_sample_failure { tracing::debug!(%cid, "Failed to sample new relay for connection"); } @@ -133,7 +132,7 @@ where tracing::info!(%cid, old = %c.relay.id, new = %rid, "Attempting to migrate connection to new relay"); - c.relay.id = *rid; + c.relay.id = rid; for candidate in allocation.current_relay_candidates() { add_local_candidate(cid, &mut c.agent, candidate, pending_events); diff --git a/rust/connlib/tunnel/proptest-regressions/tests.txt b/rust/connlib/tunnel/proptest-regressions/tests.txt index ef1edacc5..9ef0db5ff 100644 --- a/rust/connlib/tunnel/proptest-regressions/tests.txt +++ b/rust/connlib/tunnel/proptest-regressions/tests.txt @@ -242,3 +242,4 @@ cc 2e19d8524474163fb96a33e084832516e2e753a1c3e969f2436ace0850bcd74c cc 19b20eeea8590ac247e6534e42344cc4ae67a5d8e964f04e81b56f344d257c7b cc 46d17b15ff020c3f4982c43d85a342c5d10f5cec34a2316282ecfbe3a684573d cc a2746c27c8acc2f163989297aba492f9c767d7d4b72fef1cb9b84b65e5cbdfea +cc cb2df5e990c2f5c60f41f3b4dfe82736369a99dae8824bcc592831b3f366bc14 diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 736f1cfd4..58cd4f8c8 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -624,7 +624,7 @@ impl ReferenceState { } Transition::Idle => {} Transition::PartitionRelaysFromPortal => { - if state.drop_direct_client_traffic || state.client.port == 3478 { + if state.drop_direct_client_traffic { state.client.exec_mut(|client| client.reset_connections()); } }