From e64189c2dee9f1ce3b5a179bfdf2ca620352adba Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sat, 9 Mar 2024 20:42:27 +1100 Subject: [PATCH] feat(snownet): log duration since intent after WG handshake completes (#3991) Preceded by some refactoring, this PR adds a log line with a very important metric: Time since connection intent after WG handshake. This is the equivalent of time-to-first-byte, i.e. how long the user needs to wait to actually send their first application data after they've tried for the firs time (and generated an intent). --- rust/connlib/snownet/src/node.rs | 310 +++++++++++------- rust/connlib/snownet/tests/lib.rs | 28 +- rust/connlib/tunnel/src/client.rs | 18 +- .../tunnel/src/control_protocol/client.rs | 8 +- rust/snownet-tests/src/main.rs | 2 + 5 files changed, 232 insertions(+), 134 deletions(-) diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index e96ab20e2..670301c48 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -284,17 +284,10 @@ where let (header, payload) = self.buffer.as_mut().split_at_mut(4); - let packet_len = match conn.tunnel.encapsulate(packet.packet(), payload) { - TunnResult::Done => return Ok(None), - TunnResult::Err(e) => return Err(Error::Encapsulate(e)), - TunnResult::WriteToNetwork(packet) => packet.len(), - TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => { - unreachable!("never returned from encapsulate") - } + let Some(packet) = conn.encapsulate(packet.packet(), payload)? else { + return Ok(None); }; - let packet = &payload[..packet_len]; - match socket { PeerSocket::Direct { dest: remote, @@ -415,14 +408,11 @@ where if conn.peer_socket != Some(remote_socket) { let is_first_connection = conn.peer_socket.is_none(); - tracing::info!(old = ?conn.peer_socket, new = ?remote_socket, "Updating remote socket"); + tracing::info!(old = ?conn.peer_socket, new = ?remote_socket, duration_since_intent = ?conn.duration_since_intent(self.last_now), "Updating remote socket"); conn.peer_socket = Some(remote_socket); conn.invalidate_candiates(); - - tracing::info!(%id, "Sending wireguard handshake"); - self.buffered_transmits - .extend(conn.force_handshake(&mut self.allocations, self.last_now)); + conn.force_handshake(&mut self.allocations, self.last_now); if is_first_connection { return Some(Event::ConnectionEstablished(id)); @@ -476,9 +466,7 @@ where for (id, c) in self.connections.iter_established_mut() { match c.handle_timeout(now, &mut self.allocations) { - Ok(Some(transmit)) => { - self.buffered_transmits.push_back(transmit); - } + Ok(()) => {} Err(ConnectionError::Wireguard(WireGuardError::ConnectionExpired)) => { expired_connections.push(id); } @@ -488,7 +476,6 @@ where Err(ConnectionError::Wireguard(e)) => { tracing::warn!(%id, ?e); } - _ => {} }; } @@ -567,6 +554,7 @@ where } #[must_use] + #[allow(clippy::too_many_arguments)] fn init_connection( &mut self, mut agent: IceAgent, @@ -574,6 +562,7 @@ where key: [u8; 32], allowed_stun_servers: HashSet, allowed_turn_servers: HashSet, + intent_sent_at: Instant, now: Instant, ) -> Connection { agent.handle_timeout(self.last_now); @@ -597,8 +586,11 @@ where turn_servers: allowed_turn_servers, next_timer_update: self.last_now, peer_socket: None, - possible_sockets: HashSet::default(), + possible_sockets: Default::default(), stats: Default::default(), + buffered_transmits: Default::default(), + buffer: Box::new([0u8; MAX_UDP_SIZE]), + intent_sent_at, signalling_completed_at: now, } } @@ -735,59 +727,28 @@ where continue; } - return match conn.tunnel.decapsulate(None, packet, buffer) { - TunnResult::Done => ControlFlow::Break(Ok(())), - TunnResult::Err(e) => ControlFlow::Break(Err(Error::Decapsulate(e))), + let handshake_complete_before_decapsulate = conn.wg_handshake_complete(); - // For WriteToTunnel{V4,V6}, boringtun returns the source IP of the packet that was tunneled to us. - // I am guessing this was done for convenience reasons. - // In our API, we parse the packets directly as an IpPacket. - // Thus, the caller can query whatever data they'd like, not just the source IP so we don't return it in addition. - TunnResult::WriteToTunnelV4(packet, ip) => { - conn.set_remote_from_wg_activity(local, from, relayed); + let control_flow = conn.decapsulate( + from, + local, + packet, + relayed, + buffer, + &mut self.allocations, + now, + ); - let ipv4_packet = - MutableIpv4Packet::new(packet).expect("boringtun verifies validity"); - debug_assert_eq!(ipv4_packet.get_source(), ip); + let handshake_complete_after_decapsulate = conn.wg_handshake_complete(); - ControlFlow::Continue((id, ipv4_packet.into())) - } - TunnResult::WriteToTunnelV6(packet, ip) => { - conn.set_remote_from_wg_activity(local, from, relayed); + // I can't think of a better way to detect this ... + if !handshake_complete_before_decapsulate && handshake_complete_after_decapsulate { + tracing::info!(%id, duration_since_intent = ?conn.duration_since_intent(now), "Completed wireguard handshake"); + } - let ipv6_packet = - MutableIpv6Packet::new(packet).expect("boringtun verifies validity"); - debug_assert_eq!(ipv6_packet.get_source(), ip); - - ControlFlow::Continue((id, ipv6_packet.into())) - } - - // During normal operation, i.e. when the tunnel is active, decapsulating a packet straight yields the decrypted packet. - // However, in case `Tunn` has buffered packets, they may be returned here instead. - // This should be fairly rare which is why we just allocate these and return them from `poll_transmit` instead. - // Overall, this results in a much nicer API for our caller and should not affect performance. - TunnResult::WriteToNetwork(bytes) => { - conn.set_remote_from_wg_activity(local, from, relayed); - - self.buffered_transmits.extend(conn.encapsulate( - bytes, - &mut self.allocations, - now, - )); - - while let TunnResult::WriteToNetwork(packet) = - conn.tunnel - .decapsulate(None, &[], self.buffer.as_mut_slice()) - { - self.buffered_transmits.extend(conn.encapsulate( - packet, - &mut self.allocations, - now, - )); - } - - ControlFlow::Break(Ok(())) - } + return match control_flow { + ControlFlow::Continue(c) => ControlFlow::Continue((id, c)), + ControlFlow::Break(b) => ControlFlow::Break(b), }; } @@ -812,6 +773,8 @@ where id: TId, allowed_stun_servers: HashSet, allowed_turn_servers: HashSet<(SocketAddr, String, String, String)>, + intent_sent_at: Instant, + now: Instant, ) -> Offer { if self.connections.initial.remove(&id).is_some() { tracing::info!("Replacing existing initial connection"); @@ -844,20 +807,20 @@ where }, }; - let existing = self.connections.initial.insert( - id, - InitialConnection { - agent, - session_key, - stun_servers: allowed_stun_servers, - turn_servers: allowed_turn_servers, - created_at: self.last_now, - }, - ); + let initial_connection = InitialConnection { + agent, + session_key, + stun_servers: allowed_stun_servers, + turn_servers: allowed_turn_servers, + created_at: now, + intent_sent_at, + }; + let duration_since_intent = initial_connection.duration_since_intent(now); + let existing = self.connections.initial.insert(id, initial_connection); debug_assert!(existing.is_none()); - tracing::info!("Establishing new connection"); + tracing::info!(?duration_since_intent, "Establishing new connection"); params } @@ -894,12 +857,14 @@ where *initial.session_key.expose_secret(), initial.stun_servers, initial.turn_servers, + initial.intent_sent_at, now, ); + let duration_since_intent = connection.duration_since_intent(now); let existing = self.connections.established.insert(id, connection); - tracing::info!(remote = %hex::encode(remote.as_bytes()), "Signalling protocol completed"); + tracing::info!(?duration_since_intent, remote = %hex::encode(remote.as_bytes()), "Signalling protocol completed"); debug_assert!(existing.is_none()); } @@ -968,6 +933,7 @@ where *offer.session_key.expose_secret(), allowed_stun_servers, allowed_turn_servers, + now, // Technically, this isn't fully correct because gateways don't send intents so we just use the current time. now, ); let existing = self.connections.established.insert(id, connection); @@ -1294,6 +1260,13 @@ struct InitialConnection { turn_servers: HashSet, created_at: Instant, + intent_sent_at: Instant, +} + +impl InitialConnection { + fn duration_since_intent(&self, now: Instant) -> Duration { + now.duration_since(self.intent_sent_at) + } } struct Connection { @@ -1307,11 +1280,16 @@ struct Connection { // Socket addresses from which we might receive data (even before we are connected). possible_sockets: HashSet, + buffered_transmits: VecDeque>, + stun_servers: HashSet, turn_servers: HashSet, stats: ConnectionStats, + buffer: Box<[u8; MAX_UDP_SIZE]>, + intent_sent_at: Instant, + signalling_completed_at: Instant, } @@ -1353,12 +1331,20 @@ impl Connection { from_connected_remote || from_possible_remote } + fn wg_handshake_complete(&self) -> bool { + self.tunnel.time_since_last_handshake().is_some() + } + + fn duration_since_intent(&self, now: Instant) -> Duration { + now.duration_since(self.intent_sent_at) + } + fn set_remote_from_wg_activity( &mut self, local: SocketAddr, dest: SocketAddr, relay_socket: Option, - ) { + ) -> PeerSocket { let remote_socket = match relay_socket { Some(relay_socket) => PeerSocket::Relay { relay: relay_socket.server(), @@ -1374,6 +1360,8 @@ impl Connection { tracing::debug!(old = ?self.peer_socket, new = ?remote_socket, "Updating remote socket from WG activity"); self.peer_socket = Some(remote_socket); } + + remote_socket } #[must_use] @@ -1399,6 +1387,10 @@ impl Connection { allocations: &mut HashMap, now: Instant, ) -> Option> { + if let Some(transmit) = self.buffered_transmits.pop_front() { + return Some(transmit); + } + loop { let transmit = self.agent.poll_transmit()?; let source = transmit.source; @@ -1442,7 +1434,7 @@ impl Connection { &mut self, now: Instant, allocations: &mut HashMap, - ) -> Result>, ConnectionError> { + ) -> Result<(), ConnectionError> { self.agent.handle_timeout(now); if self @@ -1458,9 +1450,9 @@ impl Connection { self.next_timer_update = now + Duration::from_secs(1); // Don't update wireguard timers until we are connected. - if self.peer_socket.is_none() { - return Ok(None); - } + let Some(peer_socket) = self.peer_socket else { + return Ok(()); + }; /// [`boringtun`] requires us to pass buffers in where it can construct its packets. /// @@ -1473,47 +1465,106 @@ impl Connection { TunnResult::Done => {} TunnResult::Err(e) => return Err(ConnectionError::Wireguard(e)), TunnResult::WriteToNetwork(b) => { - let Some(transmit) = self.encapsulate(b, allocations, now) else { - return Ok(None); - }; - - return Ok(Some(transmit.into_owned())); + self.buffered_transmits.extend(make_owned_transmit( + peer_socket, + b, + allocations, + now, + )); } _ => panic!("Unexpected result from update_timers"), }; } - Ok(None) + Ok(()) } - #[must_use] - fn encapsulate( - &self, - message: &[u8], + fn encapsulate<'b>( + &mut self, + packet: &[u8], + buffer: &'b mut [u8], + ) -> Result, Error> { + let len = match self.tunnel.encapsulate(packet, buffer) { + TunnResult::Done => return Ok(None), + TunnResult::Err(e) => return Err(Error::Encapsulate(e)), + TunnResult::WriteToNetwork(packet) => packet.len(), + TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => { + unreachable!("never returned from encapsulate") + } + }; + + Ok(Some(&buffer[..len])) + } + + #[allow(clippy::too_many_arguments)] + fn decapsulate<'b>( + &mut self, + from: SocketAddr, + local: SocketAddr, + packet: &[u8], + relayed: Option, + buffer: &'b mut [u8], allocations: &mut HashMap, now: Instant, - ) -> Option> { - match self.peer_socket? { - PeerSocket::Direct { - dest: remote, - source, - } => Some(Transmit { - src: Some(source), - dst: remote, - payload: Cow::Owned(message.into()), - }), - PeerSocket::Relay { relay, dest: peer } => { - encode_as_channel_data(relay, peer, message, allocations, now).ok() + ) -> ControlFlow, MutableIpPacket<'b>> { + match self.tunnel.decapsulate(None, packet, buffer) { + TunnResult::Done => ControlFlow::Break(Ok(())), + TunnResult::Err(e) => ControlFlow::Break(Err(Error::Decapsulate(e))), + + // For WriteToTunnel{V4,V6}, boringtun returns the source IP of the packet that was tunneled to us. + // I am guessing this was done for convenience reasons. + // In our API, we parse the packets directly as an IpPacket. + // Thus, the caller can query whatever data they'd like, not just the source IP so we don't return it in addition. + TunnResult::WriteToTunnelV4(packet, ip) => { + self.set_remote_from_wg_activity(local, from, relayed); + + let ipv4_packet = + MutableIpv4Packet::new(packet).expect("boringtun verifies validity"); + debug_assert_eq!(ipv4_packet.get_source(), ip); + + ControlFlow::Continue(ipv4_packet.into()) + } + TunnResult::WriteToTunnelV6(packet, ip) => { + self.set_remote_from_wg_activity(local, from, relayed); + + let ipv6_packet = + MutableIpv6Packet::new(packet).expect("boringtun verifies validity"); + debug_assert_eq!(ipv6_packet.get_source(), ip); + + ControlFlow::Continue(ipv6_packet.into()) + } + + // During normal operation, i.e. when the tunnel is active, decapsulating a packet straight yields the decrypted packet. + // However, in case `Tunn` has buffered packets, they may be returned here instead. + // This should be fairly rare which is why we just allocate these and return them from `poll_transmit` instead. + // Overall, this results in a much nicer API for our caller and should not affect performance. + TunnResult::WriteToNetwork(bytes) => { + let socket = self.set_remote_from_wg_activity(local, from, relayed); + + self.buffered_transmits.extend(make_owned_transmit( + socket, + bytes, + allocations, + now, + )); + + while let TunnResult::WriteToNetwork(packet) = + self.tunnel.decapsulate(None, &[], self.buffer.as_mut()) + { + self.buffered_transmits.extend(make_owned_transmit( + socket, + packet, + allocations, + now, + )); + } + + ControlFlow::Break(Ok(())) } } } - #[must_use] - fn force_handshake( - &mut self, - allocations: &mut HashMap, - now: Instant, - ) -> Option> { + fn force_handshake(&mut self, allocations: &mut HashMap, now: Instant) { /// [`boringtun`] requires us to pass buffers in where it can construct its packets. /// /// When updating the timers, the largest packet that we may have to send is `148` bytes as per `HANDSHAKE_INIT_SZ` constant in [`boringtun`]. @@ -1524,10 +1575,15 @@ impl Connection { let TunnResult::WriteToNetwork(bytes) = self.tunnel.format_handshake_initiation(&mut buf, true) else { - return None; + return; }; - self.encapsulate(bytes, allocations, now) + let socket = self + .peer_socket + .expect("cannot force handshake without socket"); + + self.buffered_transmits + .extend(make_owned_transmit(socket, bytes, allocations, now)); } /// Invalidates all local candidates with a lower or equal priority compared to the nominated one. @@ -1567,6 +1623,30 @@ impl Connection { } } +#[must_use] +fn make_owned_transmit( + socket: PeerSocket, + message: &[u8], + allocations: &mut HashMap, + now: Instant, +) -> Option> { + let transmit = match socket { + PeerSocket::Direct { + dest: remote, + source, + } => Transmit { + src: Some(source), + dst: remote, + payload: Cow::Owned(message.into()), + }, + PeerSocket::Relay { relay, dest: peer } => { + encode_as_channel_data(relay, peer, message, allocations, now).ok()? + } + }; + + Some(transmit) +} + #[derive(Debug)] enum ConnectionError { Wireguard(WireGuardError), diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index d451c8f4a..6ca606f38 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -10,12 +10,18 @@ use str0m::{net::Protocol, Candidate}; #[test] fn connection_times_out_after_20_seconds() { - let start = Instant::now(); + let (mut alice, _) = alice_and_bob(Instant::now()); - let (mut alice, _) = alice_and_bob(start); + let created_at = Instant::now(); - let _ = alice.new_connection(1, HashSet::new(), HashSet::new()); - alice.handle_timeout(start + Duration::from_secs(20)); + let _ = alice.new_connection( + 1, + HashSet::new(), + HashSet::new(), + Instant::now(), + created_at, + ); + alice.handle_timeout(created_at + Duration::from_secs(20)); assert_eq!(alice.poll_event().unwrap(), Event::ConnectionFailed(1)); } @@ -88,7 +94,13 @@ fn only_generate_candidate_event_after_answer() { Instant::now(), ); - let offer = alice.new_connection(1, HashSet::new(), HashSet::new()); + let offer = alice.new_connection( + 1, + HashSet::new(), + HashSet::new(), + Instant::now(), + Instant::now(), + ); assert_eq!( alice.poll_event(), @@ -127,6 +139,8 @@ fn second_connection_with_same_relay_reuses_allocation() { 1, HashSet::new(), HashSet::from([relay("user1", "pass1", "realm1")]), + Instant::now(), + Instant::now(), ); let transmit = alice.poll_transmit().unwrap(); @@ -137,6 +151,8 @@ fn second_connection_with_same_relay_reuses_allocation() { 2, HashSet::new(), HashSet::from([relay("user1", "pass1", "realm1")]), + Instant::now(), + Instant::now(), ); assert!(alice.poll_transmit().is_none()); @@ -150,7 +166,7 @@ fn alice_and_bob(start: Instant) -> (ClientNode, ServerNode) { } fn send_offer(alice: &mut ClientNode, bob: &mut ServerNode, now: Instant) -> Answer { - let offer = alice.new_connection(1, HashSet::new(), HashSet::new()); + let offer = alice.new_connection(1, HashSet::new(), HashSet::new(), Instant::now(), now); bob.accept_connection( 1, diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 20f2e98a9..d4546e605 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -272,10 +272,10 @@ pub struct ClientState { } #[derive(Debug, Clone, PartialEq, Eq)] -struct AwaitingConnectionDetails { - domain: Option, +pub(crate) struct AwaitingConnectionDetails { + pub domain: Option, gateways: HashSet, - last_intent_sent_at: Instant, + pub last_intent_sent_at: Instant, } impl ClientState { @@ -352,15 +352,13 @@ impl ClientState { } } - pub(crate) fn get_awaiting_connection_domain( + pub(crate) fn get_awaiting_connection( &self, resource: &ResourceId, - ) -> Result<&Option, ConnlibError> { - Ok(&self - .awaiting_connection + ) -> Result<&AwaitingConnectionDetails, ConnlibError> { + self.awaiting_connection .get(resource) - .ok_or(Error::UnexpectedConnectionDetails)? - .domain) + .ok_or(Error::UnexpectedConnectionDetails) } pub(crate) fn attempt_to_reuse_connection( @@ -373,7 +371,7 @@ impl ClientState { .get(&resource) .ok_or(Error::UnknownResource)?; - let domain = self.get_awaiting_connection_domain(&resource)?.clone(); + let domain = self.get_awaiting_connection(&resource)?.domain.clone(); if self.is_connected_to(resource, &domain) { return Err(Error::UnexpectedConnectionDetails); diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index 3e9baa5de..4f9193207 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -68,9 +68,9 @@ where return Err(Error::PendingConnection); } - let domain = self + let awaiting_connection = self .role_state - .get_awaiting_connection_domain(&resource_id)? + .get_awaiting_connection(&resource_id)? .clone(); let offer = self.connections_state.node.new_connection( @@ -81,6 +81,8 @@ where turn(&relays, |addr| { self.connections_state.sockets.can_handle(addr) }), + awaiting_connection.last_intent_sent_at, + Instant::now(), ); Ok(Request::NewConnection(RequestConnection { @@ -92,7 +94,7 @@ where username: offer.credentials.username, password: offer.credentials.password, }, - domain, + domain: awaiting_connection.domain, }, })) } diff --git a/rust/snownet-tests/src/main.rs b/rust/snownet-tests/src/main.rs index bab97a689..3a969bb31 100644 --- a/rust/snownet-tests/src/main.rs +++ b/rust/snownet-tests/src/main.rs @@ -82,6 +82,8 @@ async fn main() -> Result<()> { 1, stun_server.into_iter().collect(), turn_server.into_iter().collect(), + Instant::now(), + Instant::now(), ); redis_connection