From 1222be8fc95ab7791ac054d3ee65bc5c53acec6b Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 4 Aug 2025 23:35:48 +1000 Subject: [PATCH] fix(snownet): de-multiplex packets based on WG session index (#10109) Right now, `snownet` de-multiplexes WireGuard packets based on their source tuple (IP + port) to the _first_ connection that would like to handle this traffic. What appears to be happening based on observation from customer logs is that we sometimes dispatch the traffic to the wrong connection. The WireGuard packet format uses session indices to declare which session a packet is for. The local session index is selected during the handshake for a particular session. By associating the different session indices (we can have up to 8 in parallel per peer) with our Firezone-specific connection ID, we can change our de-multiplexing scheme to use these indices instead of the source tuple. This is especially important for Gateways as those talk to multiple different clients. The session index is a 32-bit integer where the top 24 bits identify the connection and the bottom 8 bits are used in a round-robin fashion to identify individual sessions within the connection. Thus, to find the correct connection, we right-shift the session index of an incoming packet by 8 bits to arrive back at the 24-bit connection identifier. In environments with a limited number of ports outside the NAT, a connection from a new Client may come from a source tuple of a previous Client. In such a case, we'd dispatch the packets to the wrong connection, causing the Client to not be able to handshake a tunnel. 
--- .github/workflows/ci.yml | 2 +- rust/Cargo.lock | 2 +- rust/connlib/snownet/src/node.rs | 258 ++++++++++++------- website/src/components/Changelog/Gateway.tsx | 7 +- 4 files changed, 169 insertions(+), 100 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 75b670216..9adba5404 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -316,7 +316,7 @@ jobs: run: | # We need to increase the log level to make sure that they don't hold off storm of packets # generated by UDP tests. Wire is especially chatty. - sed -i 's/^\(\s*\)RUST_LOG:.*$/\1RUST_LOG: wire=error,opentelemetry_sdk=error,info/' docker-compose.yml + sed -i 's/^\(\s*\)RUST_LOG:.*$/\1RUST_LOG: wire=error,opentelemetry_sdk=error,debug/' docker-compose.yml grep RUST_LOG docker-compose.yml # Start services in the same order each time for the tests diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 9f2195723..ccd469825 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -951,7 +951,7 @@ dependencies = [ [[package]] name = "boringtun" version = "0.6.1" -source = "git+https://github.com/firezone/boringtun?branch=master#84a33359f4281e29139f12dafbd15e1da97df1b4" +source = "git+https://github.com/firezone/boringtun?branch=master#bda7276b68396a591b454605eff30717e692a194" dependencies = [ "aead", "base64 0.22.1", diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 0002a5a40..83a2d3f1c 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -4,7 +4,9 @@ use crate::stats::{ConnectionStats, NodeStats}; use crate::utils::channel_data_packet_buffer; use anyhow::{Context, Result, anyhow}; use boringtun::noise::errors::WireGuardError; -use boringtun::noise::{Tunn, TunnResult}; +use boringtun::noise::{ + HandshakeResponse, Packet, PacketCookieReply, PacketData, Tunn, TunnResult, +}; use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use 
bufferpool::{Buffer, BufferPool}; @@ -312,10 +314,10 @@ where &mut self.pending_events, ); - let connection = + let (index, connection) = self.init_connection(cid, agent, remote, preshared_key, selected_relay, now, now); - self.connections.established.insert(cid, connection); + self.connections.insert_established(cid, index, connection); Ok(()) } @@ -682,19 +684,20 @@ where relay: RId, intent_sent_at: Instant, now: Instant, - ) -> Connection { + ) -> (u32, Connection) { agent.handle_timeout(now); if self.allocations.is_empty() { tracing::warn!(%cid, "No TURN servers connected; connection may fail to establish"); } + let index = self.index.next(); let mut tunnel = Tunn::new_at( self.private_key.clone(), remote, Some(key), None, - self.index.next(), + index, Some(self.rate_limiter.clone()), self.rng.next_u64(), now, @@ -711,25 +714,28 @@ where // until we have a WireGuard tunnel to send packets into. tunnel.set_rekey_attempt_time(Duration::from_secs(15)); - Connection { - agent, - tunnel, - next_wg_timer_update: now, - stats: Default::default(), - buffer: vec![0; ip_packet::MAX_FZ_PAYLOAD], - intent_sent_at, - signalling_completed_at: now, - remote_pub_key: remote, - relay, - state: ConnectionState::Connecting { - wg_buffer: AllocRingBuffer::new(128), - ip_buffer: AllocRingBuffer::new(128), + ( + index, + Connection { + agent, + tunnel, + next_wg_timer_update: now, + stats: Default::default(), + buffer: vec![0; ip_packet::MAX_FZ_PAYLOAD], + intent_sent_at, + signalling_completed_at: now, + remote_pub_key: remote, + relay, + state: ConnectionState::Connecting { + wg_buffer: AllocRingBuffer::new(128), + ip_buffer: AllocRingBuffer::new(128), + }, + disconnected_at: None, + buffer_pool: self.buffer_pool.clone(), + last_proactive_handshake_sent_at: None, + first_handshake_completed_at: None, }, - disconnected_at: None, - possible_sockets: BTreeSet::default(), - buffer_pool: self.buffer_pool.clone(), - last_proactive_handshake_sent_at: None, - } + ) } /// Tries to handle 
the packet using one of our [`Allocation`]s. @@ -865,51 +871,83 @@ where packet: &[u8], now: Instant, ) -> ControlFlow, (TId, IpPacket)> { - for (cid, conn) in self.connections.iter_established_mut() { - if !conn.accepts(&from) { - continue; + // If the packet is not a WireGuard packet, bail early. + let Ok(parsed_packet) = boringtun::noise::Tunn::parse_incoming_packet(packet) else { + tracing::debug!(packet = %hex::encode(packet)); + + return ControlFlow::Break(Err(anyhow::Error::msg("Not a WireGuard packet"))); + }; + + let (cid, conn) = match &parsed_packet { + // When receiving a handshake, we need to look-up the peer by its public key because we don't have a session-index mapping yet. + Packet::HandshakeInit(handshake_init) => { + let handshake = match boringtun::noise::handshake::parse_handshake_anon( + &self.private_key, + &self.public_key, + handshake_init, + ) + .context("Failed to parse handshake init") + { + Ok(handshake) => handshake, + Err(e) => return ControlFlow::Break(Err(e)), + }; + + let Some((cid, connection)) = self + .connections + .get_established_mut_by_public_key(handshake.peer_static_public) + else { + return ControlFlow::Break(Err(anyhow::Error::msg(format!( + "Received handshake for unknown public key: {}", + hex::encode(handshake.peer_static_public) + )))); + }; + + (cid, connection) } + // For all other packets, grab the session index and look up the corresponding connection. + Packet::HandshakeResponse(HandshakeResponse { receiver_idx, .. }) + | Packet::PacketCookieReply(PacketCookieReply { receiver_idx, .. }) + | Packet::PacketData(PacketData { receiver_idx, .. 
}) => { + let Some((cid, connection)) = self + .connections + .get_established_mut_session_index(*receiver_idx) + else { + return ControlFlow::Break(Err(anyhow::Error::msg(format!( + "Received packet for unknown session index: {receiver_idx}" + )))); + }; - let handshake_complete_before_decapsulate = conn.wg_handshake_complete(now); - - let control_flow = conn.decapsulate( - cid, - from.ip(), - packet, - &mut self.allocations, - &mut self.buffered_transmits, - now, - ); - - let handshake_complete_after_decapsulate = conn.wg_handshake_complete(now); - - // I can't think of a better way to detect this ... - if !handshake_complete_before_decapsulate && handshake_complete_after_decapsulate { - tracing::info!(%cid, duration_since_intent = ?conn.duration_since_intent(now), "Completed wireguard handshake"); - - self.pending_events - .push_back(Event::ConnectionEstablished(cid)) + (cid, connection) } + }; - return match control_flow { - ControlFlow::Continue(c) => ControlFlow::Continue((cid, c)), - ControlFlow::Break(b) => ControlFlow::Break( - b.with_context(|| format!("cid={cid} length={}", packet.len())), - ), - }; + let control_flow = conn.decapsulate( + cid, + from.ip(), + packet, + &mut self.allocations, + &mut self.buffered_transmits, + now, + ); + + if let ControlFlow::Break(Ok(())) = &control_flow + && conn.first_handshake_completed_at.is_none() + && matches!( + parsed_packet, + Packet::HandshakeInit(_) | Packet::HandshakeResponse(_) + ) + { + conn.first_handshake_completed_at = Some(now); + + tracing::info!(%cid, duration_since_intent = ?conn.duration_since_intent(now), "Completed wireguard handshake"); + + self.pending_events + .push_back(Event::ConnectionEstablished(cid)) } - if crate::is_wireguard(packet) { - tracing::trace!( - "Packet was a WireGuard packet but no connection handled it. Already disconnected?" 
- ); - - return ControlFlow::Break(Ok(())); - } - - tracing::debug!(packet = %hex::encode(packet)); - - ControlFlow::Break(Err(anyhow!("Packet has unknown format"))) + control_flow + .map_continue(|c| (cid, c)) + .map_break(|b| b.with_context(|| format!("cid={cid} length={}", packet.len()))) } fn allocations_drain_events(&mut self) { @@ -1035,7 +1073,7 @@ where &mut self.pending_events, ); - let connection = self.init_connection( + let (index, connection) = self.init_connection( cid, agent, remote, @@ -1046,7 +1084,7 @@ where ); let duration_since_intent = connection.duration_since_intent(now); - let existing = self.connections.established.insert(cid, connection); + let existing = self.connections.insert_established(cid, index, connection); tracing::info!(?duration_since_intent, remote = %hex::encode(remote.as_bytes()), "Signalling protocol completed"); @@ -1106,7 +1144,7 @@ where &mut self.pending_events, ); - let connection = self.init_connection( + let (index, connection) = self.init_connection( cid, agent, remote, @@ -1115,7 +1153,7 @@ where now, // Technically, this isn't fully correct because gateways don't send intents so we just use the current time. 
now, ); - let existing = self.connections.established.insert(cid, connection); + let existing = self.connections.insert_established(cid, index, connection); debug_assert!(existing.is_none()); @@ -1218,6 +1256,8 @@ fn generate_optimistic_candidates(agent: &mut IceAgent) { struct Connections { initial: BTreeMap>, established: BTreeMap>, + + established_by_wireguard_session_index: BTreeMap, } impl Default for Connections { @@ -1225,6 +1265,7 @@ impl Default for Connections { Self { initial: Default::default(), established: Default::default(), + established_by_wireguard_session_index: Default::default(), } } } @@ -1247,6 +1288,8 @@ where self.established.retain(|id, conn| { if conn.is_failed() { events.push_back(Event::ConnectionFailed(*id)); + self.established_by_wireguard_session_index + .retain(|_, c| c != id); return false; } @@ -1297,6 +1340,29 @@ where self.established.iter().map(move |(id, c)| (*id, c.stats)) } + fn insert_established( + &mut self, + id: TId, + index: u32, + connection: Connection, + ) -> Option> { + let existing = self.established.insert(id, connection); + + debug_assert_eq!( + index >> 24, + 0, + "The 8 most-significant bits should always be zero." + ); + + // Remove previous mappings for connection. + self.established_by_wireguard_session_index + .retain(|_, c| c != &id); + self.established_by_wireguard_session_index + .insert(index, id); + + existing + } + fn agent_mut(&mut self, id: TId) -> Option<(&mut IceAgent, RId)> { let maybe_initial_connection = self.initial.get_mut(&id).map(|i| (&mut i.agent, i.relay)); let maybe_established_connection = self @@ -1334,6 +1400,32 @@ where self.established.get_mut(id) } + fn get_established_mut_session_index( + &mut self, + index: u32, + ) -> Option<(TId, &mut Connection)> { + // Drop the 8 least-significant bits. Those are used by boringtun to identify individual sessions. + // In order to find the original `Tunn` instance, we just want to compare the higher 24 bits. 
+ let index = index >> 8; + + let id = self.established_by_wireguard_session_index.get(&index)?; + let connection = self.established.get_mut(id)?; + + Some((*id, connection)) + } + + fn get_established_mut_by_public_key( + &mut self, + key: [u8; 32], + ) -> Option<(TId, &mut Connection)> { + let (id, conn) = self + .established + .iter_mut() + .find(|(_, c)| c.tunnel.remote_static_public().as_bytes() == &key)?; + + Some((*id, conn)) + } + fn iter_initial_mut(&mut self) -> impl Iterator)> { self.initial.iter_mut().map(|(id, conn)| (*id, conn)) } @@ -1353,6 +1445,7 @@ where fn clear(&mut self) { self.initial.clear(); self.established.clear(); + self.established_by_wireguard_session_index.clear(); } fn iter_ids(&self) -> impl Iterator + '_ { @@ -1608,12 +1701,10 @@ struct Connection { state: ConnectionState, disconnected_at: Option, - /// Socket addresses from which we might receive data (even before we are connected). - possible_sockets: BTreeSet, - stats: ConnectionStats, intent_sent_at: Instant, signalling_completed_at: Instant, + first_handshake_completed_at: Option, buffer: Vec, @@ -1837,31 +1928,6 @@ impl Connection where RId: PartialEq + Eq + Hash + fmt::Debug + fmt::Display + Copy + Ord, { - /// Checks if we want to accept a packet from a certain address. - /// - /// Whilst we establish connections, we may see traffic from a certain address, prior to the negotiation being fully complete. - /// We already want to accept that traffic and not throw it away. - #[must_use] - fn accepts(&self, addr: &SocketAddr) -> bool { - let from_nominated = match &self.state { - ConnectionState::Idle { peer_socket } - | ConnectionState::Connected { peer_socket, .. } => match &peer_socket { - PeerSocket::PeerToPeer { dest, .. } | PeerSocket::PeerToRelay { dest, .. } => { - dest == addr - } - PeerSocket::RelayToPeer { dest: remote, .. } - | PeerSocket::RelayToRelay { dest: remote, .. } => remote == addr, - }, - ConnectionState::Failed | ConnectionState::Connecting { .. 
} => false, - }; - - from_nominated || self.possible_sockets.contains(addr) - } - - fn wg_handshake_complete(&self, now: Instant) -> bool { - self.tunnel.time_since_last_handshake_at(now).is_some() - } - fn duration_since_intent(&self, now: Instant) -> Duration { now.duration_since(self.intent_sent_at) } @@ -1950,9 +2016,7 @@ where while let Some(event) = self.agent.poll_event() { match event { - IceAgentEvent::DiscoveredRecv { source, .. } => { - self.possible_sockets.insert(source); - } + IceAgentEvent::DiscoveredRecv { .. } => {} IceAgentEvent::IceConnectionStateChange(IceConnectionState::Disconnected) => { tracing::debug!(grace_period = ?DISCONNECT_TIMEOUT, "Received ICE disconnect"); diff --git a/website/src/components/Changelog/Gateway.tsx b/website/src/components/Changelog/Gateway.tsx index cc440bac4..425cd7f6d 100644 --- a/website/src/components/Changelog/Gateway.tsx +++ b/website/src/components/Changelog/Gateway.tsx @@ -22,7 +22,12 @@ export default function Gateway() { return ( - + + + Fixes an issue where connections would fail to establish in + environments with a limited number of ports on the NAT. + + Fixes an issue where a Client could not establish a connection unless