diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index a00b7bae0..ea5cb8507 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -193,6 +193,7 @@ where /// /// To do that, we need to check all candidates of each allocation and compare their IP. /// The same relay might be reachable over IPv4 and IPv6. + #[must_use] fn same_relay_as_peer(&mut self, id: TId, candidate: &Candidate) -> Option<&mut Allocation> { self.allocations .iter_mut() @@ -272,6 +273,9 @@ where .get_established_mut(&connection) .ok_or(Error::NotConnected)?; + // Must bail early if we don't have a socket yet to avoid running into WG timeouts. + let socket = conn.peer_socket.ok_or(Error::NotConnected)?; + let (header, payload) = self.buffer.as_mut().split_at_mut(4); let packet_len = match conn.tunnel.encapsulate(packet.packet(), payload) { @@ -285,7 +289,7 @@ where let packet = &payload[..packet_len]; - match conn.peer_socket.ok_or(Error::NotConnected)? { + match socket { PeerSocket::Direct { dest: remote, source, @@ -321,6 +325,7 @@ where } /// Returns a pending [`Event`] from the pool. + #[must_use] pub fn poll_event(&mut self) -> Option> { let binding_events = self.bindings.iter_mut().flat_map(|(server, binding)| { iter::from_fn(|| binding.poll_event().map(|e| (*server, e))) @@ -411,6 +416,12 @@ where conn.peer_socket = Some(remote_socket); if is_first_connection { + tracing::info!(%id, "Starting wireguard handshake"); + + self.buffered_transmits.extend( + conn.force_handshake(&mut self.allocations, self.last_now), + ); + return Some(Event::ConnectionEstablished(id)); } } @@ -434,6 +445,7 @@ where /// /// This function only takes `&mut self` because it caches certain computations internally. /// The returned timestamp will **not** change unless other state is modified. + #[must_use] pub fn poll_timeout(&mut self) -> Option { let mut connection_timeout = None; @@ -513,6 +525,7 @@ where } /// Returns buffered data that needs to be sent on the socket. + #[must_use] pub fn poll_transmit(&mut self) -> Option> { for (_, conn) in self.connections.iter_established_mut() { if let Some(transmit) = conn.poll_transmit(&mut self.allocations, self.last_now) { @@ -535,6 +548,7 @@ where self.buffered_transmits.pop_front() } + #[must_use] fn init_connection( &mut self, mut agent: IceAgent, @@ -588,6 +602,7 @@ where Ok(()) } + #[must_use] fn bindings_try_handle( &mut self, from: SocketAddr, @@ -609,6 +624,7 @@ where } /// Tries to handle the packet using one of our [`Allocation`]s. + #[must_use] fn allocations_try_handle<'p>( &mut self, from: SocketAddr, @@ -645,6 +661,7 @@ where } } + #[must_use] fn agents_try_handle( &mut self, from: SocketAddr, @@ -677,6 +694,7 @@ where })) } + #[must_use] fn connections_try_handle<'b>( &mut self, from: SocketAddr, @@ -762,6 +780,7 @@ where /// Out of all configured STUN and TURN servers, the connection will only use the ones provided here. /// The returned [`Offer`] must be passed to the remote via a signalling channel. #[tracing::instrument(level = "info", skip_all, fields(%id))] + #[must_use] pub fn new_connection( &mut self, id: TId, @@ -812,11 +831,13 @@ where debug_assert!(existing.is_none()); + tracing::info!("Establishing new connection"); + params } /// Accept an [`Answer`] from the remote for a connection previously created via [`Node::new_connection`]. - #[tracing::instrument(level = "debug", skip_all, fields(%id, remote = %hex::encode(remote.as_bytes())))] + #[tracing::instrument(level = "info", skip_all, fields(%id))] pub fn accept_answer(&mut self, id: TId, remote: PublicKey, answer: Answer) { let Some(initial) = self.connections.initial.remove(&id) else { tracing::debug!("No initial connection state, ignoring answer"); // This can happen if the connection setup timed out. @@ -846,6 +867,8 @@ where let existing = self.connections.established.insert(id, connection); + tracing::info!(remote = %hex::encode(remote.as_bytes()), "Signalling protocol completed"); + debug_assert!(existing.is_none()); } } @@ -859,6 +882,7 @@ where /// Out of all configured STUN and TURN servers, the connection will only use the ones provided here. /// The returned [`Answer`] must be passed to the remote via a signalling channel. #[tracing::instrument(level = "info", skip_all, fields(%id))] + #[must_use] pub fn accept_connection( &mut self, id: TId, @@ -916,6 +940,8 @@ where debug_assert!(existing.is_none()); + tracing::info!("Created new connection"); + answer } } @@ -1271,6 +1297,7 @@ impl Connection { /// /// Whilst we establish connections, we may see traffic from a certain address, prior to the negotiation being fully complete. /// We already want to accept that traffic and not throw it away. + #[must_use] fn accepts(&self, addr: SocketAddr) -> bool { let from_connected_remote = self.peer_socket.is_some_and(|r| match r { PeerSocket::Direct { dest, .. } => dest == addr, @@ -1304,6 +1331,7 @@ impl Connection { } } + #[must_use] fn poll_timeout(&mut self) -> Option { let agent_timeout = self.agent.poll_timeout(); let next_wg_timer = Some(self.next_timer_update); @@ -1311,6 +1339,7 @@ impl Connection { earliest(agent_timeout, next_wg_timer) } + #[must_use] fn poll_transmit( &mut self, allocations: &mut HashMap, @@ -1392,6 +1421,7 @@ impl Connection { Ok(None) } + #[must_use] fn encapsulate( &self, message: &[u8], @@ -1412,4 +1442,26 @@ impl Connection { } } } + + #[must_use] + fn force_handshake( + &mut self, + allocations: &mut HashMap, + now: Instant, + ) -> Option> { + /// [`boringtun`] requires us to pass buffers in where it can construct its packets. + /// + /// When updating the timers, the largest packet that we may have to send is `148` bytes as per `HANDSHAKE_INIT_SZ` constant in [`boringtun`]. + const MAX_SCRATCH_SPACE: usize = 148; + + let mut buf = [0u8; MAX_SCRATCH_SPACE]; + + let TunnResult::WriteToNetwork(bytes) = + self.tunnel.format_handshake_initiation(&mut buf, true) + else { + return None; + }; + + self.encapsulate(bytes, allocations, now) + } }