diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 99c7ba56c..980e64803 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -2,7 +2,6 @@ use crate::{ backoff::{self, ExponentialBackoff}, node::{SessionId, Transmit}, utils::earliest, - EncryptedPacket, }; use ::backoff::backoff::Backoff; use bytecodec::{DecodeExt as _, EncodeExt as _}; @@ -11,7 +10,6 @@ use hex_display::HexDisplayExt as _; use rand::random; use ringbuffer::{AllocRingBuffer, RingBuffer as _}; use std::{ - borrow::Cow, collections::{BTreeMap, VecDeque}, net::{SocketAddr, SocketAddrV4, SocketAddrV6}, time::{Duration, Instant}, @@ -763,40 +761,32 @@ impl Allocation { ); } - pub fn encode_to_encrypted_packet( - &self, - peer: SocketAddr, - mut buffer: lockfree_object_pool::SpinLockOwnedReusable>, - buffer_len: usize, - now: Instant, - ) -> Option { - let packet_len = buffer_len - 4; - - let channel_number = self.channel_bindings.connected_channel_to_peer(peer, now)?; - crate::channel_data::encode_header_to_slice(&mut buffer[..4], channel_number, packet_len); - - Some(EncryptedPacket { - src: None, - dst: self.active_socket?, - packet_start: 0, - packet_len: buffer_len, - buffer, - }) - } - - pub fn encode_to_owned_transmit( + pub fn encode_channel_data_header( &mut self, peer: SocketAddr, - packet: &[u8], + buffer: &mut [u8], now: Instant, - ) -> Option> { - let channel_number = self.channel_bindings.connected_channel_to_peer(peer, now)?; - let channel_data = crate::channel_data::encode(channel_number, packet); + ) -> Option { + let active_socket = self.active_socket?; + let payload_length = buffer.len() - 4; - Some(Transmit { - src: None, - dst: self.active_socket?, - payload: Cow::Owned(channel_data), + let channel_number = match self.channel_bindings.connected_channel_to_peer(peer, now) { + Some(cn) => cn, + None => { + tracing::debug!(%peer, %active_socket, "No channel to peer, binding new one"); + self.bind_channel(peer, now); + + return None; + } + }; + crate::channel_data::encode_header_to_slice( + &mut buffer[..4], + channel_number, + payload_length, + ); + + Some(EncodeOk { + socket: active_socket, }) } @@ -1083,6 +1073,10 @@ impl Allocation { } } +pub struct EncodeOk { + pub socket: SocketAddr, +} + fn authenticate(message: Message, credentials: &Credentials) -> Message { let attributes = message .attributes() @@ -1511,6 +1505,8 @@ fn display_attr(attr: &Attribute) -> String { #[cfg(test)] mod tests { + use crate::utils::channel_data_packet_buffer; + use super::*; use std::{ iter, @@ -1624,7 +1620,8 @@ mod tests { let channel = channel_bindings.new_channel_to_peer(PEER1, start).unwrap(); channel_bindings.set_confirmed(channel, start + Duration::from_secs(1)); - let packet = crate::channel_data::encode(channel, b"foobar"); + let mut packet = channel_data_packet_buffer(b"foobar"); + crate::channel_data::encode_header_to_slice(&mut packet[..4], channel, 6); let (peer, payload) = channel_bindings .try_decode(&packet, start + Duration::from_secs(2)) .unwrap(); @@ -1641,7 +1638,8 @@ mod tests { let channel = channel_bindings.new_channel_to_peer(PEER1, start).unwrap(); channel_bindings.set_confirmed(channel, start + Duration::from_secs(1)); - let packet = crate::channel_data::encode(channel, b"foobar"); + let mut packet = channel_data_packet_buffer(b"foobar"); + crate::channel_data::encode_header_to_slice(&mut packet[..4], channel, 6); channel_bindings .try_decode(&packet, start + Duration::from_secs(2)) .unwrap(); @@ -1788,7 +1786,7 @@ mod tests { } #[test] - fn does_not_relay_to_with_unbound_channel() { + fn does_relay_to_with_bound_channel() { let mut allocation = Allocation::for_test_ip4(Instant::now()) .with_binding_response(PEER1) .with_allocate_response(&[RELAY_ADDR_IP4]); @@ -1800,25 +1798,26 @@ mod tests { Instant::now(), ); - let transmit = allocation - .encode_to_owned_transmit(PEER2_IP4, b"foobar", Instant::now()) + let mut buffer = channel_data_packet_buffer(b"foobar"); + let encode_ok = allocation + .encode_channel_data_header(PEER2_IP4, &mut buffer, Instant::now()) .unwrap(); - assert_eq!(&transmit.payload[4..], b"foobar"); - assert_eq!(transmit.src, None); - assert_eq!(transmit.dst, RELAY_V4.into()); + assert_eq!(encode_ok.socket, RELAY_V4.into()); } #[test] - fn does_relay_to_with_bound_channel() { + fn does_not_relay_to_with_unbound_channel() { let mut allocation = Allocation::for_test_ip4(Instant::now()) .with_binding_response(PEER1) .with_allocate_response(&[RELAY_ADDR_IP4]); allocation.bind_channel(PEER2_IP4, Instant::now()); - let message = allocation.encode_to_owned_transmit(PEER2_IP4, b"foobar", Instant::now()); + let mut buffer = channel_data_packet_buffer(b"foobar"); + let encode_ok = + allocation.encode_channel_data_header(PEER2_IP4, &mut buffer, Instant::now()); - assert!(message.is_none()) + assert!(encode_ok.is_none()) } #[test] @@ -2105,7 +2104,8 @@ mod tests { Instant::now(), ); - let msg = allocation.encode_to_owned_transmit(PEER2_IP4, b"foobar", Instant::now()); + let mut packet = channel_data_packet_buffer(b"foobar"); + let msg = allocation.encode_channel_data_header(PEER2_IP4, &mut packet, Instant::now()); assert!(msg.is_some(), "expect to have a channel to peer"); allocation.refresh_with_same_credentials(); @@ -2113,7 +2113,8 @@ mod tests { let refresh = allocation.next_message().unwrap(); allocation.handle_test_input_ip4(&allocation_mismatch(&refresh), Instant::now()); - let msg = allocation.encode_to_owned_transmit(PEER2_IP4, b"foobar", Instant::now()); + let mut packet = channel_data_packet_buffer(b"foobar"); + let msg = allocation.encode_channel_data_header(PEER2_IP4, &mut packet, Instant::now()); assert!(msg.is_none(), "expect to no longer have a channel to peer"); } diff --git a/rust/connlib/snownet/src/channel_data.rs b/rust/connlib/snownet/src/channel_data.rs index 7d15d2eba..ea1465180 100644 --- a/rust/connlib/snownet/src/channel_data.rs +++ b/rust/connlib/snownet/src/channel_data.rs @@ -1,4 +1,4 @@ -use bytes::{BufMut, BytesMut}; +use bytes::BufMut; use std::io; const HEADER_LEN: usize = 4; @@ -36,14 +36,6 @@ pub fn decode(data: &[u8]) -> Result<(u16, &[u8]), io::Error> { Ok((channel_number, &payload[..length])) } -pub fn encode(channel: u16, data: &[u8]) -> Vec { - debug_assert!(channel > 0x400); - debug_assert!(channel < 0x7FFF); - debug_assert!(data.len() <= u16::MAX as usize); - - to_bytes(channel, data.len() as u16, data) -} - /// Encode the channel data header (number + length) to the given slice. /// /// Returns the total length of the packet (i.e. the encoded header + data). @@ -59,13 +51,3 @@ pub fn encode_header_to_slice(mut slice: &mut [u8], channel: u16, payload_length HEADER_LEN + payload_length } - -fn to_bytes(channel: u16, len: u16, payload: &[u8]) -> Vec { - let mut message = BytesMut::with_capacity(HEADER_LEN + (len as usize)); - - message.put_u16(channel); - message.put_u16(len); - message.put_slice(payload); - - message.freeze().into() -} diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 3a18ff137..6fe0d0ac0 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -2,7 +2,7 @@ use crate::allocation::{self, Allocation, RelaySocket, Socket}; use crate::candidate_set::CandidateSet; use crate::index::IndexLfsr; use crate::stats::{ConnectionStats, NodeStats}; -use crate::utils::earliest; +use crate::utils::{channel_data_packet_buffer, earliest}; use boringtun::noise::errors::WireGuardError; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::PublicKey; @@ -485,18 +485,23 @@ where buffer, })), PeerSocket::Relay { relay, dest: peer } => { - let Some(allocation) = self.allocations.get(&relay) else { + let Some(allocation) = self.allocations.get_mut(&relay) else { tracing::warn!(%relay, "No allocation"); return Ok(None); }; - let Some(enc_packet) = - allocation.encode_to_encrypted_packet(peer, buffer, packet_end, now) + let Some(encode_ok) = + allocation.encode_channel_data_header(peer, &mut buffer[..packet_end], now) else { - tracing::warn!(%peer, "No channel"); return Ok(None); }; - Ok(Some(enc_packet)) + Ok(Some(EncryptedPacket { + src: None, + dst: encode_ok.socket, + packet_start: 0, + packet_len: packet_end, + buffer, + })) } } } @@ -1321,37 +1326,6 @@ where } } -/// Wraps the message as a channel data message via the relay, iff: -/// -/// - `relay` is in fact a relay -/// - We have an allocation on the relay -/// - There is a channel bound to the provided peer -fn encode_as_channel_data( - relay: RId, - dest: SocketAddr, - contents: &[u8], - allocations: &mut BTreeMap, - now: Instant, -) -> Result, EncodeError> -where - RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug, -{ - let allocation = allocations - .get_mut(&relay) - .ok_or(EncodeError::NoAllocation)?; - let transmit = allocation - .encode_to_owned_transmit(dest, contents, now) - .ok_or(EncodeError::NoChannel)?; - - Ok(transmit) -} - -#[derive(Debug)] -enum EncodeError { - NoAllocation, - NoChannel, -} - fn add_local_candidate( id: TId, agent: &mut IceAgent, @@ -1956,7 +1930,7 @@ where while let Some(transmit) = self.agent.poll_transmit() { let source = transmit.source; let dst = transmit.destination; - let packet = transmit.contents; + let stun_packet = transmit.contents; // Check if `str0m` wants us to send from a "remote" socket, i.e. one that we allocated with a relay. let allocation = allocations @@ -1964,27 +1938,35 @@ where .find(|(_, allocation)| allocation.has_socket(source)); let Some((relay, allocation)) = allocation else { - self.stats.stun_bytes_to_peer_direct += packet.len(); + self.stats.stun_bytes_to_peer_direct += stun_packet.len(); // `source` did not match any of our allocated sockets, must be a local one then! transmits.push_back(Transmit { src: Some(source), dst, - payload: Cow::Owned(packet.into()), + payload: Cow::Owned(stun_packet.into()), }); continue; }; + let mut data_channel_packet = channel_data_packet_buffer(&stun_packet); + // Payload should be sent from a "remote socket", let's wrap it in a channel data message! - let Some(channel_data) = allocation.encode_to_owned_transmit(dst, &packet, now) else { + let Some(encode_ok) = + allocation.encode_channel_data_header(dst, &mut data_channel_packet, now) + else { // Unlikely edge-case, drop the packet and continue. tracing::trace!(%relay, peer = %dst, "Dropping packet because allocation does not offer a channel to peer"); continue; }; - self.stats.stun_bytes_to_peer_relayed += channel_data.payload.len(); + self.stats.stun_bytes_to_peer_relayed += data_channel_packet.len(); - transmits.push_back(channel_data); + transmits.push_back(Transmit { + src: None, + dst: encode_ok.socket, + payload: Cow::Owned(data_channel_packet), + }); } } @@ -2166,7 +2148,16 @@ where payload: Cow::Owned(message.into()), }, PeerSocket::Relay { relay, dest: peer } => { - encode_as_channel_data(relay, peer, message, allocations, now).ok()? + let allocation = allocations.get_mut(&relay)?; + + let mut buffer = channel_data_packet_buffer(message); + let encode_ok = allocation.encode_channel_data_header(peer, &mut buffer, now)?; + + Transmit { + src: None, + dst: encode_ok.socket, + payload: Cow::Owned(buffer), + } } }; diff --git a/rust/connlib/snownet/src/utils.rs b/rust/connlib/snownet/src/utils.rs index b26639545..b37e6c6bb 100644 --- a/rust/connlib/snownet/src/utils.rs +++ b/rust/connlib/snownet/src/utils.rs @@ -8,3 +8,7 @@ pub fn earliest(left: Option, right: Option) -> Option Some(right), } } + +pub fn channel_data_packet_buffer(payload: &[u8]) -> Vec { + [&[0u8; 4] as &[u8], payload].concat() +}