diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c1d210069..b5e85d033 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -951,7 +951,7 @@ dependencies = [ [[package]] name = "boringtun" version = "0.6.1" -source = "git+https://github.com/firezone/boringtun?branch=master#bda7276b68396a591b454605eff30717e692a194" +source = "git+https://github.com/firezone/boringtun?branch=master#ed1de7c6ddf071d2895309f0fb153e9afb82fc99" dependencies = [ "aead", "base64 0.22.1", diff --git a/rust/connlib/snownet/src/index.rs b/rust/connlib/snownet/src/index.rs index ea93000a1..ed1354365 100644 --- a/rust/connlib/snownet/src/index.rs +++ b/rust/connlib/snownet/src/index.rs @@ -1,3 +1,4 @@ +use boringtun::noise::Index; use rand::Rng; // A basic linear-feedback shift register implemented as xorshift, used to @@ -36,7 +37,7 @@ impl IndexLfsr { } /// Generate the next value in the pseudorandom sequence - pub(crate) fn next(&mut self) -> u32 { + pub(crate) fn next(&mut self) -> Index { // 24-bit polynomial for randomness. This is arbitrarily chosen to // inject bitflips into the value. const LFSR_POLY: u32 = 0xd80000; // 24-bit polynomial @@ -44,6 +45,7 @@ impl IndexLfsr { let value = self.lfsr - 1; // lfsr will never have value of 0 self.lfsr = (self.lfsr >> 1) ^ ((0u32.wrapping_sub(self.lfsr & 1u32)) & LFSR_POLY); assert!(self.lfsr != self.initial, "Too many peers created"); - value ^ self.mask + + Index::new_local(value ^ self.mask) } } diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 6053e1528..a159ae5df 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -5,7 +5,7 @@ use crate::utils::channel_data_packet_buffer; use anyhow::{Context, Result, anyhow}; use boringtun::noise::errors::WireGuardError; use boringtun::noise::{ - HandshakeResponse, Packet, PacketCookieReply, PacketData, Tunn, TunnResult, + HandshakeResponse, Index, Packet, PacketCookieReply, PacketData, Tunn, TunnResult, }; use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; @@ -291,12 +291,13 @@ where } let existing = self.connections.established.remove(&cid); + let index = self.index.next(); if let Some(existing) = existing { let current_local = existing.agent.local_credentials(); - tracing::info!(?current_local, new_local = ?local_creds, remote = ?remote_creds, "Replacing existing connection"); + tracing::info!(?current_local, new_local = ?local_creds, remote = ?remote_creds, %index, "Replacing existing connection"); } else { - tracing::info!(local = ?local_creds, remote = ?remote_creds, "Creating new connection"); + tracing::info!(local = ?local_creds, remote = ?remote_creds, %index, "Creating new connection"); } let selected_relay = self.sample_relay()?; @@ -314,8 +315,16 @@ where &mut self.pending_events, ); - let (index, connection) = - self.init_connection(cid, agent, remote, preshared_key, selected_relay, now, now); + let connection = self.init_connection( + cid, + agent, + remote, + preshared_key, + selected_relay, + index, + now, + now, + ); self.connections.insert_established(cid, index, connection); @@ -682,16 +691,16 @@ where remote: PublicKey, key: [u8; 32], relay: RId, + index: Index, intent_sent_at: Instant, now: Instant, - ) -> (u32, Connection) { + ) -> Connection { agent.handle_timeout(now); if self.allocations.is_empty() { tracing::warn!(%cid, "No TURN servers connected; connection may fail to establish"); } - let index = self.index.next(); let mut tunnel = Tunn::new_at( self.private_key.clone(), remote, @@ -714,31 +723,28 @@ where // until we have a WireGuard tunnel to send packets into. tunnel.set_rekey_attempt_time(Duration::from_secs(15)); - ( - index, - Connection { - agent, - tunnel, - next_wg_timer_update: now, - stats: Default::default(), - buffer: vec![0; ip_packet::MAX_FZ_PAYLOAD], - intent_sent_at, - signalling_completed_at: now, - remote_pub_key: remote, - relay: SelectedRelay { - id: relay, - logged_sample_failure: false, - }, - state: ConnectionState::Connecting { - wg_buffer: AllocRingBuffer::new(128), - ip_buffer: AllocRingBuffer::new(128), - }, - disconnected_at: None, - buffer_pool: self.buffer_pool.clone(), - last_proactive_handshake_sent_at: None, - first_handshake_completed_at: None, + Connection { + agent, + tunnel, + next_wg_timer_update: now, + stats: Default::default(), + buffer: vec![0; ip_packet::MAX_FZ_PAYLOAD], + intent_sent_at, + signalling_completed_at: now, + remote_pub_key: remote, + relay: SelectedRelay { + id: relay, + logged_sample_failure: false, }, - ) + state: ConnectionState::Connecting { + wg_buffer: AllocRingBuffer::new(128), + ip_buffer: AllocRingBuffer::new(128), + }, + disconnected_at: None, + buffer_pool: self.buffer_pool.clone(), + last_proactive_handshake_sent_at: None, + first_handshake_completed_at: None, + } } /// Tries to handle the packet using one of our [`Allocation`]s. @@ -895,33 +901,24 @@ where Err(e) => return ControlFlow::Break(Err(e)), }; - let Some((cid, connection)) = self + match self .connections .get_established_mut_by_public_key(handshake.peer_static_public) - else { - return ControlFlow::Break(Err(anyhow::Error::msg(format!( - "Received handshake for unknown public key: {}", - hex::encode(handshake.peer_static_public) - )))); - }; - - (cid, connection) + { + Ok(c) => c, + Err(e) => return ControlFlow::Break(Err(e)), + } } // For all other packets, grab the session index and look up the corresponding connection. Packet::HandshakeResponse(HandshakeResponse { receiver_idx, .. }) | Packet::PacketCookieReply(PacketCookieReply { receiver_idx, .. }) - | Packet::PacketData(PacketData { receiver_idx, .. }) => { - let Some((cid, connection)) = self - .connections - .get_established_mut_session_index(*receiver_idx) - else { - return ControlFlow::Break(Err(anyhow::Error::msg(format!( - "Received packet for unknown session index: {receiver_idx}" - )))); - }; - - (cid, connection) - } + | Packet::PacketData(PacketData { receiver_idx, .. }) => match self + .connections + .get_established_mut_session_index(Index::from_peer(*receiver_idx)) + { + Ok(c) => c, + Err(e) => return ControlFlow::Break(Err(e)), + }, }; let control_flow = conn.decapsulate( @@ -1076,12 +1073,14 @@ where &mut self.pending_events, ); - let (index, connection) = self.init_connection( + let index = self.index.next(); + let connection = self.init_connection( cid, agent, remote, *initial.session_key.expose_secret(), selected_relay, + index, initial.intent_sent_at, now, ); @@ -1147,12 +1146,14 @@ where &mut self.pending_events, ); - let (index, connection) = self.init_connection( + let index = self.index.next(); + let connection = self.init_connection( cid, agent, remote, *offer.session_key.expose_secret(), selected_relay, + index, now, // Technically, this isn't fully correct because gateways don't send intents so we just use the current time. now, ); @@ -1260,7 +1261,7 @@ struct Connections { initial: BTreeMap>, established: BTreeMap>, - established_by_wireguard_session_index: BTreeMap, + established_by_wireguard_session_index: BTreeMap, } impl Default for Connections { @@ -1350,22 +1351,16 @@ where fn insert_established( &mut self, id: TId, - index: u32, + index: Index, connection: Connection, ) -> Option> { let existing = self.established.insert(id, connection); - debug_assert_eq!( - index >> 24, - 0, - "The 8 most-significant bits should always be zero." - ); - // Remove previous mappings for connection. self.established_by_wireguard_session_index .retain(|_, c| c != &id); self.established_by_wireguard_session_index - .insert(index, id); + .insert(index.global(), id); existing } @@ -1409,28 +1404,31 @@ where fn get_established_mut_session_index( &mut self, - index: u32, - ) -> Option<(TId, &mut Connection)> { - // Drop the 8 least-significant bits. Those are used by boringtun to identify individual sessions. - // In order to find the original `Tunn` instance, we just want to compare the higher 24 bits. - let index = index >> 8; + index: Index, + ) -> Result<(TId, &mut Connection)> { + let id = self + .established_by_wireguard_session_index + .get(&index.global()) + .with_context(|| format!("No connection for with index {index}"))?; + let connection = self + .established + .get_mut(id) + .with_context(|| format!("No connection for ID {id}"))?; - let id = self.established_by_wireguard_session_index.get(&index)?; - let connection = self.established.get_mut(id)?; - - Some((*id, connection)) + Ok((*id, connection)) } fn get_established_mut_by_public_key( &mut self, key: [u8; 32], - ) -> Option<(TId, &mut Connection)> { + ) -> Result<(TId, &mut Connection)> { let (id, conn) = self .established .iter_mut() - .find(|(_, c)| c.tunnel.remote_static_public().as_bytes() == &key)?; + .find(|(_, c)| c.tunnel.remote_static_public().as_bytes() == &key) + .with_context(|| format!("No connection with public key {}", hex::encode(key)))?; - Some((*id, conn)) + Ok((*id, conn)) } fn iter_initial_mut(&mut self) -> impl Iterator)> {