diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 3f8e78471..b4dc10a99 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -460,77 +460,27 @@ where return Ok(None); } - let mut buffer = self.buffer_pool.pull(); - buffer.resize(ip_packet::MAX_FZ_PAYLOAD, 0); - - // Encode the packet with an offset of 4 bytes, in case we need to wrap it in a channel-data message. - let Some(packet_len) = conn - .encapsulate(cid, packet, &mut buffer[4..], now) - .with_context(|| format!("cid={cid}"))? - .map(|p| p.len()) - // Mapping to len() here terminate the mutable borrow of buffer, allowing re-borrowing further down. - else { - return Ok(None); - }; - - let packet_start = 4; - let packet_end = 4 + packet_len; - let socket = match &mut conn.state { - ConnectionState::Connecting { buffered, .. } => { - buffered.push(buffer[packet_start..packet_end].to_vec()); - let num_buffered = buffered.len(); + ConnectionState::Connecting { ip_buffer, .. } => { + ip_buffer.push(packet); + let num_buffered = ip_buffer.len(); tracing::debug!(%num_buffered, %cid, "ICE is still in progress, buffering WG handshake"); return Ok(None); } - ConnectionState::Connected { peer_socket, .. } => peer_socket, - ConnectionState::Idle { peer_socket } => peer_socket, + ConnectionState::Connected { peer_socket, .. } => *peer_socket, + ConnectionState::Idle { peer_socket } => *peer_socket, ConnectionState::Failed => { return Err(anyhow!("Connection {cid} failed")); } }; - match *socket { - PeerSocket::PeerToPeer { - source, - dest: remote, - } - | PeerSocket::PeerToRelay { - source, - dest: remote, - } => { - buffer.copy_within(packet_start..packet_end, 0); - buffer.truncate(packet_len); + let maybe_transmit = conn + .encapsulate(cid, socket, packet, now, &mut self.allocations) + .with_context(|| format!("cid={cid}"))?; - Ok(Some(Transmit { - src: Some(source), - dst: remote, - payload: buffer, - })) - } - PeerSocket::RelayToPeer { relay, dest: peer } - | PeerSocket::RelayToRelay { relay, dest: peer } => { - let Some(allocation) = self.allocations.get_mut(&relay) else { - tracing::warn!(%relay, %cid, "No allocation"); - return Ok(None); - }; - let Some(encode_ok) = - allocation.encode_channel_data_header(peer, &mut buffer[..packet_end], now) - else { - return Ok(None); - }; - - buffer.truncate(packet_end); - - Ok(Some(Transmit { - src: None, - dst: encode_ok.socket, - payload: buffer, - })) - } - } + Ok(maybe_transmit) } /// Returns a pending [`Event`] from the pool. @@ -783,7 +733,8 @@ where remote_pub_key: remote, state: ConnectionState::Connecting { relay: Some(relay), - buffered: AllocRingBuffer::new(128), + wg_buffer: AllocRingBuffer::new(128), + ip_buffer: AllocRingBuffer::new(128), }, disconnected_at: None, possible_sockets: BTreeSet::default(), @@ -1716,9 +1667,12 @@ enum ConnectionState { /// /// This can happen if the remote's WG session initiation arrives at our socket before we nominate it. /// A session initiation requires a response that we must not drop, otherwise the connection setup experiences unnecessary delays. + wg_buffer: AllocRingBuffer>, + + /// Packets we are told to send whilst we are still running ICE. /// - /// It can also happen if we attempt to encapsulate a packet prior to the WireGuard handshake which triggers the creation of a WireGuard handshake initiation packet. - buffered: AllocRingBuffer>, + /// These need to be encrypted and sent once the tunnel is established. + ip_buffer: AllocRingBuffer, }, /// A socket has been nominated. Connected { @@ -1884,9 +1838,15 @@ enum PeerSocket { }, } +impl PeerSocket { + fn send_from_relay(&self) -> bool { + matches!(self, Self::RelayToPeer { .. } | Self::RelayToRelay { .. }) + } +} + impl Connection where - RId: PartialEq + Eq + Hash + fmt::Debug + Copy + Ord, + RId: PartialEq + Eq + Hash + fmt::Debug + fmt::Display + Copy + Ord, { /// Checks if we want to accept a packet from a certain address. /// @@ -2052,12 +2012,18 @@ where }; let old = match mem::replace(&mut self.state, ConnectionState::Failed) { - ConnectionState::Connecting { buffered, .. } => { - let num_buffered = buffered.len(); + ConnectionState::Connecting { + wg_buffer, + ip_buffer, + .. + } => { + tracing::debug!( + num_buffered = %wg_buffer.len(), + %cid, + "Flushing WireGuard packets buffered during ICE" + ); - tracing::debug!(%num_buffered, "Flushing packets buffered during ICE"); - - transmits.extend(buffered.into_iter().flat_map(|packet| { + transmits.extend(wg_buffer.into_iter().flat_map(|packet| { make_owned_transmit( remote_socket, &packet, @@ -2066,12 +2032,31 @@ where now, ) })); + + tracing::debug!( + num_buffered = %ip_buffer.len(), + %cid, + "Flushing IP packets buffered during ICE" + ); + transmits.extend(ip_buffer.into_iter().flat_map(|packet| { + let transmit = self + .encapsulate(cid, remote_socket, packet, now, allocations) + .inspect_err(|e| { + tracing::debug!( + %cid, + "Failed to encapsulate buffered IP packet: {e:#}" + ) + }) + .ok()??; + + Some(transmit) + })); + self.state = ConnectionState::Connected { peer_socket: remote_socket, last_incoming: now, last_outgoing: now, }; - None } ConnectionState::Connected { @@ -2206,28 +2191,74 @@ where }; } - fn encapsulate<'b, TId>( + fn encapsulate( &mut self, cid: TId, + socket: PeerSocket, packet: IpPacket, - buffer: &'b mut [u8], now: Instant, - ) -> Result> + allocations: &mut BTreeMap, + ) -> Result> where TId: fmt::Display, { self.state.on_outgoing(cid, &mut self.agent, &packet, now); - let len = match self.tunnel.encapsulate_at(packet.packet(), buffer, now) { - TunnResult::Done => return Ok(None), - TunnResult::Err(e) => return Err(anyhow::Error::new(e)), - TunnResult::WriteToNetwork(packet) => packet.len(), - TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => { - unreachable!("never returned from encapsulate") - } - }; + let packet_start = if socket.send_from_relay() { 4 } else { 0 }; - Ok(Some(&buffer[..len])) + let mut buffer = self.buffer_pool.pull(); + buffer.resize(ip_packet::MAX_FZ_PAYLOAD, 0); + + let len = + match self + .tunnel + .encapsulate_at(packet.packet(), &mut buffer[packet_start..], now) + { + TunnResult::Done => return Ok(None), + TunnResult::Err(e) => return Err(anyhow::Error::new(e)), + TunnResult::WriteToNetwork(packet) => packet.len(), + TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => { + unreachable!("never returned from encapsulate") + } + }; + + let packet_end = packet_start + len; + buffer.truncate(packet_end); + + match socket { + PeerSocket::PeerToPeer { + source, + dest: remote, + } + | PeerSocket::PeerToRelay { + source, + dest: remote, + } => Ok(Some(Transmit { + src: Some(source), + dst: remote, + payload: buffer, + })), + PeerSocket::RelayToPeer { relay, dest: peer } + | PeerSocket::RelayToRelay { relay, dest: peer } => { + let Some(allocation) = allocations.get_mut(&relay) else { + tracing::warn!(%relay, "No allocation"); + return Ok(None); + }; + let Some(encode_ok) = + allocation.encode_channel_data_header(peer, &mut buffer[..packet_end], now) + else { + return Ok(None); + }; + + buffer.truncate(packet_end); + + Ok(Some(Transmit { + src: None, + dst: encode_ok.socket, + payload: buffer, + })) + } + } } fn decapsulate( @@ -2284,16 +2315,16 @@ where // Overall, this results in a much nicer API for our caller and should not affect performance. TunnResult::WriteToNetwork(bytes) => { match &mut self.state { - ConnectionState::Connecting { buffered, .. } => { + ConnectionState::Connecting { wg_buffer, .. } => { tracing::debug!(%cid, "No socket has been nominated yet, buffering WG packet"); - buffered.push(bytes.to_owned()); + wg_buffer.push(bytes.to_owned()); while let TunnResult::WriteToNetwork(packet) = self.tunnel .decapsulate_at(None, &[], self.buffer.as_mut(), now) { - buffered.push(packet.to_owned()); + wg_buffer.push(packet.to_owned()); } } ConnectionState::Connected { peer_socket, .. }