diff --git a/rust/Cargo.lock b/rust/Cargo.lock index ec0f2d006..7db900c0a 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -5809,7 +5809,7 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "str0m" version = "0.5.0" -source = "git+https://github.com/firezone/str0m?branch=main#aeb62dfe53270d29d2cc72b03930a462e55b2e88" +source = "git+https://github.com/firezone/str0m?branch=main#1a69339a76ea21fa526d7a90893e3549e0281e0f" dependencies = [ "combine", "crc", diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index fc821164f..6a7b1b18c 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -1,7 +1,7 @@ use crate::{ messages::{ - BroadcastGatewayIceCandidates, Connect, ConnectionDetails, EgressMessages, - GatewayIceCandidates, IngressMessages, InitClient, ReplyMessages, + Connect, ConnectionDetails, EgressMessages, GatewayIceCandidates, GatewaysIceCandidates, + IngressMessages, InitClient, ReplyMessages, }, PHOENIX_TOPIC, }; @@ -99,15 +99,29 @@ where fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) { match event { - firezone_tunnel::ClientEvent::SignalIceCandidate { + firezone_tunnel::ClientEvent::NewIceCandidate { conn_id: gateway, candidate, } => { - tracing::debug!(%gateway, %candidate, "Sending ICE candidate to gateway"); + tracing::debug!(%gateway, %candidate, "Sending new ICE candidate to gateway"); self.portal.send( PHOENIX_TOPIC, - EgressMessages::BroadcastIceCandidates(BroadcastGatewayIceCandidates { + EgressMessages::BroadcastIceCandidates(GatewaysIceCandidates { + gateway_ids: vec![gateway], + candidates: vec![candidate], + }), + ); + } + firezone_tunnel::ClientEvent::InvalidatedIceCandidate { + conn_id: gateway, + candidate, + } => { + tracing::debug!(%gateway, %candidate, "Sending invalidated ICE candidate to gateway"); + + self.portal.send( + PHOENIX_TOPIC, + EgressMessages::BroadcastInvalidatedIceCandidates(GatewaysIceCandidates { gateway_ids: vec![gateway], candidates: vec![candidate], }), @@ -200,6 +214,14 @@ where IngressMessages::ResourceDeleted(resource) => { self.tunnel.remove_resources(&[resource]); } + IngressMessages::InvalidateIceCandidates(GatewayIceCandidates { + gateway_id, + candidates, + }) => { + for candidate in candidates { + self.tunnel.add_ice_candidate(gateway_id, candidate) + } + } } } diff --git a/rust/connlib/clients/shared/src/messages.rs b/rust/connlib/clients/shared/src/messages.rs index 426804a06..8b05b41e1 100644 --- a/rust/connlib/clients/shared/src/messages.rs +++ b/rust/connlib/clients/shared/src/messages.rs @@ -47,20 +47,19 @@ pub enum IngressMessages { ResourceDeleted(ResourceId), IceCandidates(GatewayIceCandidates), + InvalidateIceCandidates(GatewayIceCandidates), ConfigChanged(ConfigUpdate), } -/// A gateway's ice candidate message. #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] -pub struct BroadcastGatewayIceCandidates { - /// Gateway's id the ice candidates are meant for +pub struct GatewaysIceCandidates { + /// The list of gateway IDs these candidates will be broadcast to. pub gateway_ids: Vec, /// Actual RTC ice candidates pub candidates: Vec, } -/// A gateway's ice candidate message. #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] pub struct GatewayIceCandidates { /// Gateway's id the ice candidates are from @@ -89,7 +88,10 @@ pub enum EgressMessages { }, RequestConnection(RequestConnection), ReuseConnection(ReuseConnection), - BroadcastIceCandidates(BroadcastGatewayIceCandidates), + /// Candidates that can be used by the addressed gateways. + BroadcastIceCandidates(GatewaysIceCandidates), + /// Candidates that should no longer be used by the addressed gateways. + BroadcastInvalidatedIceCandidates(GatewaysIceCandidates), } #[cfg(test)] @@ -108,7 +110,7 @@ mod test { let message = r#"{"topic":"client","event":"broadcast_ice_candidates","payload":{"gateway_ids":["b3d34a15-55ab-40df-994b-a838e75d65d7"],"candidates":["candidate:7031633958891736544 1 udp 50331391 35.244.108.190 53909 typ relay"]},"ref":6}"#; let expected = PhoenixMessage::new_message( "client", - EgressMessages::BroadcastIceCandidates(BroadcastGatewayIceCandidates { + EgressMessages::BroadcastIceCandidates(GatewaysIceCandidates { gateway_ids: vec!["b3d34a15-55ab-40df-994b-a838e75d65d7".parse().unwrap()], candidates: vec![ "candidate:7031633958891736544 1 udp 50331391 35.244.108.190 53909 typ relay" @@ -123,6 +125,22 @@ mod test { assert_eq!(ingress_message, expected); } + #[test] + fn invalidate_ice_candidates_message() { + let msg = r#"{"event":"invalidate_ice_candidates","ref":null,"topic":"client","payload":{"candidates":["candidate:7854631899965427361 1 udp 1694498559 172.28.0.100 47717 typ srflx"],"gateway_id":"2b1524e6-239e-4570-bc73-70a188e12101"}}"#; + let expected = IngressMessages::InvalidateIceCandidates(GatewayIceCandidates { + gateway_id: "2b1524e6-239e-4570-bc73-70a188e12101".parse().unwrap(), + candidates: vec![ + "candidate:7854631899965427361 1 udp 1694498559 172.28.0.100 47717 typ srflx" + .to_owned(), + ], + }); + + let actual = serde_json::from_str::(msg).unwrap(); + + assert_eq!(actual, expected); + } + #[test] fn connection_ready_deserialization() { let message = r#"{ diff --git a/rust/connlib/shared/src/messages.rs b/rust/connlib/shared/src/messages.rs index d59312eef..c92993d9d 100644 --- a/rust/connlib/shared/src/messages.rs +++ b/rust/connlib/shared/src/messages.rs @@ -46,6 +46,14 @@ impl ResourceId { #[derive(Hash, Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)] pub struct ClientId(Uuid); +impl FromStr for ClientId { + type Err = uuid::Error; + + fn from_str(s: &str) -> Result { + Ok(ClientId(Uuid::parse_str(s)?)) + } +} + impl FromStr for ResourceId { type Err = uuid::Error; diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 3dceba1a0..16cc3fc73 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -6,7 +6,6 @@ use crate::{ }; use ::backoff::backoff::Backoff; use bytecodec::{DecodeExt as _, EncodeExt as _}; -use core::fmt; use rand::random; use std::{ collections::{HashMap, VecDeque}, @@ -36,9 +35,7 @@ const REQUEST_TIMEOUT: Duration = Duration::from_secs(1); /// /// Allocations have a lifetime and need to be continuously refreshed to stay active. #[derive(Debug)] -pub struct Allocation { - id: RId, - +pub struct Allocation { server: SocketAddr, /// If present, the last address the relay observed for us. @@ -73,32 +70,19 @@ pub struct Allocation { /// Note that any combination of IP versions is possible here. /// We might have allocated an IPv6 address on a TURN server that we are talking to IPv4 and vice versa. #[derive(Debug, Clone, Copy)] -pub struct Socket { - /// The ID of the relay. - id: RId, +pub struct Socket { /// The address of the socket that was allocated. address: SocketAddr, } -impl Socket -where - RId: Copy, -{ - pub fn id(&self) -> RId { - self.id - } - +impl Socket { pub fn address(&self) -> SocketAddr { self.address } } -impl Allocation -where - RId: Copy + fmt::Debug, -{ +impl Allocation { pub fn new( - id: RId, server: SocketAddr, username: Username, password: String, @@ -106,7 +90,6 @@ where now: Instant, ) -> Self { let mut allocation = Self { - id, server, last_srflx_candidate: Default::default(), ip4_allocation: Default::default(), @@ -405,7 +388,7 @@ where from: SocketAddr, packet: &'p [u8], now: Instant, - ) -> Option<(SocketAddr, &'p [u8], Socket)> { + ) -> Option<(SocketAddr, &'p [u8], Socket)> { if from != self.server { return None; } @@ -612,26 +595,20 @@ where self.server } - pub fn ip4_socket(&self) -> Option> { + pub fn ip4_socket(&self) -> Option { let address = self.ip4_allocation.as_ref().map(|c| c.addr())?; debug_assert!(address.is_ipv4()); - Some(Socket { - id: self.id, - address, - }) + Some(Socket { address }) } - pub fn ip6_socket(&self) -> Option> { + pub fn ip6_socket(&self) -> Option { let address = self.ip6_allocation.as_ref().map(|c| c.addr())?; debug_assert!(address.is_ipv6()); - Some(Socket { - id: self.id, - address, - }) + Some(Socket { address }) } fn has_allocation(&self) -> bool { @@ -1775,10 +1752,10 @@ mod tests { let channel_bind_peer_2 = allocation.next_message().unwrap(); assert_eq!(channel_bind_peer_1.method(), CHANNEL_BIND); - assert_eq!(peer_address(&channel_bind_peer_1), PEER2_IP4); + assert_eq!(peer_address(&channel_bind_peer_1), PEER1); assert_eq!(channel_bind_peer_2.method(), CHANNEL_BIND); - assert_eq!(peer_address(&channel_bind_peer_2), PEER1); + assert_eq!(peer_address(&channel_bind_peer_2), PEER2_IP4); } #[test] @@ -2042,10 +2019,9 @@ mod tests { message.get_attribute::().unwrap().address() } - impl Allocation { + impl Allocation { fn for_test(start: Instant) -> Self { Allocation::new( - 1, RELAY, Username::new("foobar".to_owned()).unwrap(), "baz".to_owned(), diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index ae1ed37e2..58d52d9bc 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -1,5 +1,6 @@ use crate::allocation::{Allocation, Socket}; use crate::index::IndexLfsr; +use crate::ringbuffer::RingBuffer; use crate::stats::{ConnectionStats, NodeStats}; use crate::stun_binding::StunBinding; use crate::utils::earliest; @@ -16,6 +17,7 @@ use secrecy::{ExposeSecret, Secret}; use std::borrow::Cow; use std::hash::Hash; use std::marker::PhantomData; +use std::mem; use std::ops::ControlFlow; use std::time::{Duration, Instant}; use std::{ @@ -87,7 +89,7 @@ pub struct Node { next_rate_limiter_reset: Option, bindings: HashMap, - allocations: HashMap>, + allocations: HashMap, connections: Connections, pending_events: VecDeque>, @@ -232,12 +234,27 @@ where } } + #[tracing::instrument(level = "info", skip_all, fields(%id))] + pub fn remove_remote_candidate(&mut self, id: TId, candidate: String) { + let candidate = match Candidate::from_sdp_string(&candidate) { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to parse candidate: {e}"); + return; + } + }; + + if let Some(agent) = self.connections.agent_mut(id) { + agent.invalidate_candidate(&candidate); + } + } + /// Attempts to find the [`Allocation`] on the same relay as the remote's candidate. /// /// To do that, we need to check all candidates of each allocation and compare their IP. /// The same relay might be reachable over IPv4 and IPv6. #[must_use] - fn same_relay_as_peer(&mut self, candidate: &Candidate) -> Option<&mut Allocation> { + fn same_relay_as_peer(&mut self, candidate: &Candidate) -> Option<&mut Allocation> { self.allocations.iter_mut().find_map(|(_, allocation)| { allocation .current_candidates() @@ -283,12 +300,11 @@ where ControlFlow::Break(Err(e)) => return Err(e), }; - let (id, packet) = - match self.connections_try_handle(from, local, packet, relayed, buffer, now) { - ControlFlow::Continue(c) => c, - ControlFlow::Break(Ok(())) => return Ok(None), - ControlFlow::Break(Err(e)) => return Err(e), - }; + let (id, packet) = match self.connections_try_handle(from, packet, buffer, now) { + ControlFlow::Continue(c) => c, + ControlFlow::Break(Ok(())) => return Ok(None), + ControlFlow::Break(Err(e)) => return Err(e), + }; Ok(Some((id, packet))) } @@ -311,7 +327,7 @@ where .ok_or(Error::NotConnected)?; // Must bail early if we don't have a socket yet to avoid running into WG timeouts. - let socket = conn.peer_socket.ok_or(Error::NotConnected)?; + let socket = conn.socket().ok_or(Error::NotConnected)?; let (header, payload) = self.buffer.as_mut().split_at_mut(4); @@ -400,7 +416,13 @@ where self.bindings_and_allocations_drain_events(); for (id, connection) in self.connections.iter_established_mut() { - connection.handle_timeout(id, now, &mut self.allocations, &mut self.buffered_transmits); + connection.handle_timeout( + id, + now, + &mut self.allocations, + &mut self.buffered_transmits, + &mut self.pending_events, + ); } for (id, connection) in self.connections.initial.iter_mut() { @@ -469,7 +491,7 @@ where self.allocations.insert( *id, - Allocation::new(*id, *server, username, password.clone(), realm, now), + Allocation::new(*server, username, password.clone(), realm, now), ); tracing::info!(address = %server, "Added new TURN server"); @@ -504,14 +526,15 @@ where Some(self.rate_limiter.clone()), ), next_timer_update: now, - peer_socket: None, - possible_sockets: Default::default(), stats: Default::default(), buffer: Box::new([0u8; MAX_UDP_SIZE]), intent_sent_at, - is_failed: false, signalling_completed_at: now, remote_pub_key: remote, + state: ConnectionState::Connecting { + possible_sockets: HashSet::default(), + buffered: RingBuffer::new(10), + }, } } @@ -577,7 +600,7 @@ where local: SocketAddr, packet: &'p [u8], now: Instant, - ) -> ControlFlow<(), (SocketAddr, &'p [u8], Option>)> { + ) -> ControlFlow<(), (SocketAddr, &'p [u8], Option)> { match packet.first().copied() { // STUN method range Some(0..=3) => { @@ -658,26 +681,21 @@ where fn connections_try_handle<'b>( &mut self, from: SocketAddr, - local: SocketAddr, packet: &[u8], - relayed: Option>, buffer: &'b mut [u8], now: Instant, ) -> ControlFlow, (TId, MutableIpPacket<'b>)> { for (id, conn) in self.connections.iter_established_mut() { let _span = info_span!("connection", %id).entered(); - if !conn.accepts(from) { + if !conn.accepts(&from) { continue; } let handshake_complete_before_decapsulate = conn.wg_handshake_complete(); let control_flow = conn.decapsulate( - from, - local, packet, - relayed, buffer, &mut self.allocations, &mut self.buffered_transmits, @@ -727,7 +745,8 @@ where CandidateEvent::Invalid(candidate) => { for (id, agent) in self.connections.agents_mut() { let _span = info_span!("connection", %id).entered(); - agent.invalidate_candidate(&candidate); + + remove_local_candidate(id, agent, &candidate, &mut self.pending_events); } } } @@ -965,6 +984,7 @@ impl Default for Connections { impl Connections where TId: Eq + Hash + Copy + fmt::Display, + RId: Copy + Eq + Hash + PartialEq + fmt::Debug + fmt::Display, { fn remove_failed(&mut self, events: &mut VecDeque>) { self.initial.retain(|id, conn| { @@ -977,7 +997,7 @@ where }); self.established.retain(|id, conn| { - if conn.is_failed { + if conn.is_failed() { events.push_back(Event::ConnectionFailed(*id)); return false; } @@ -1033,7 +1053,7 @@ fn encode_as_channel_data( relay: RId, dest: SocketAddr, contents: &[u8], - allocations: &mut HashMap>, + allocations: &mut HashMap, now: Instant, ) -> Result, EncodeError> where @@ -1093,7 +1113,25 @@ fn add_local_candidate( let is_new = agent.add_local_candidate(candidate.clone()); if is_new { - pending_events.push_back(Event::SignalIceCandidate { + pending_events.push_back(Event::NewIceCandidate { + connection: id, + candidate: candidate.to_sdp_string(), + }) + } +} + +fn remove_local_candidate( + id: TId, + agent: &mut IceAgent, + candidate: &Candidate, + pending_events: &mut VecDeque>, +) where + TId: fmt::Display, +{ + let was_present = agent.invalidate_candidate(candidate); + + if was_present { + pending_events.push_back(Event::InvalidateIceCandidate { connection: id, candidate: candidate.to_sdp_string(), }) @@ -1119,13 +1157,22 @@ pub struct Credentials { #[derive(Debug, PartialEq, Clone)] pub enum Event { - /// Signal the ICE candidate to the remote via the signalling channel. + /// We created a new candidate for this connection and ask to signal it to the remote party. /// /// Candidates are in SDP format although this may change and should be considered an implementation detail of the application. - SignalIceCandidate { + NewIceCandidate { connection: TId, candidate: String, }, + + /// We invalidated a candidate for this connection and ask to signal that to the remote party. + /// + /// Candidates are in SDP format although this may change and should be considered an implementation detail of the application. + InvalidateIceCandidate { + connection: TId, + candidate: String, + }, + ConnectionEstablished(TId), /// We failed to establish a connection. @@ -1195,24 +1242,55 @@ impl InitialConnection { struct Connection { agent: IceAgent, - remote_pub_key: PublicKey, - tunnel: Tunn, + remote_pub_key: PublicKey, next_timer_update: Instant, - // When this is `Some`, we are connected. - peer_socket: Option>, - // Socket addresses from which we might receive data (even before we are connected). - possible_sockets: HashSet, + state: ConnectionState, stats: ConnectionStats, + intent_sent_at: Instant, + signalling_completed_at: Instant, buffer: Box<[u8; MAX_UDP_SIZE]>, - intent_sent_at: Instant, +} - is_failed: bool, +enum ConnectionState { + /// We are still running ICE to figure out, which socket to use to send data. + Connecting { + /// Socket addresses from which we might receive data (even before we are connected). + possible_sockets: HashSet, + /// Packets emitted by wireguard whilst are still running ICE. + /// + /// This can happen if the remote's WG session initiation arrives at our socket before we nominate it. + /// A session initiation requires a response that we must not drop, otherwise the connection setup experiences unnecessary delays. + buffered: RingBuffer>, + }, + /// A socket has been nominated. + Connected { + /// Our nominated socket. + peer_socket: PeerSocket, + /// Other addresses that we might see traffic from (e.g. STUN messages during roaming). + possible_sockets: HashSet, + }, + /// The connection failed in an unrecoverable way and will be GC'd. + Failed, +} - signalling_completed_at: Instant, +impl ConnectionState { + fn add_possible_socket(&mut self, socket: SocketAddr) { + let possible_sockets = match self { + ConnectionState::Connecting { + possible_sockets, .. + } => possible_sockets, + ConnectionState::Connected { + possible_sockets, .. + } => possible_sockets, + ConnectionState::Failed => return, + }; + + possible_sockets.insert(socket); + } } /// The socket of the peer we are connected to. @@ -1237,14 +1315,24 @@ where /// Whilst we establish connections, we may see traffic from a certain address, prior to the negotiation being fully complete. /// We already want to accept that traffic and not throw it away. #[must_use] - fn accepts(&self, addr: SocketAddr) -> bool { - let from_connected_remote = self.peer_socket.is_some_and(|r| match r { - PeerSocket::Direct { dest, .. } => dest == addr, - PeerSocket::Relay { dest, .. } => dest == addr, - }); - let from_possible_remote = self.possible_sockets.contains(&addr); + fn accepts(&self, addr: &SocketAddr) -> bool { + match &self.state { + ConnectionState::Connecting { + possible_sockets, .. + } => possible_sockets.contains(addr), + ConnectionState::Connected { + peer_socket, + possible_sockets, + } => { + let from_nominated = match peer_socket { + PeerSocket::Direct { dest, .. } => dest == addr, + PeerSocket::Relay { dest, .. } => dest == addr, + }; - from_connected_remote || from_possible_remote + from_nominated || possible_sockets.contains(addr) + } + ConnectionState::Failed => false, + } } fn wg_handshake_complete(&self) -> bool { @@ -1255,31 +1343,6 @@ where now.duration_since(self.intent_sent_at) } - fn set_remote_from_wg_activity( - &mut self, - local: SocketAddr, - dest: SocketAddr, - relay_socket: Option>, - ) -> PeerSocket { - let remote_socket = match relay_socket { - Some(relay_socket) => PeerSocket::Relay { - relay: relay_socket.id(), - dest, - }, - None => PeerSocket::Direct { - source: local, - dest, - }, - }; - - if self.peer_socket != Some(remote_socket) { - tracing::debug!(old = ?self.peer_socket, new = ?remote_socket, "Updating remote socket from WG activity"); - self.peer_socket = Some(remote_socket); - } - - remote_socket - } - #[must_use] fn poll_timeout(&mut self) -> Option { let agent_timeout = self.agent.poll_timeout(); @@ -1302,8 +1365,9 @@ where &mut self, id: TId, now: Instant, - allocations: &mut HashMap>, + allocations: &mut HashMap, transmits: &mut VecDeque>, + pending_events: &mut VecDeque>, ) where TId: fmt::Display + Copy, RId: Copy + fmt::Display, @@ -1315,7 +1379,7 @@ where .is_some_and(|timeout| now >= timeout) { tracing::info!("Connection failed (no candidates received)"); - self.is_failed = true; + self.state = ConnectionState::Failed; return; } @@ -1325,7 +1389,7 @@ where self.next_timer_update = now + Duration::from_secs(1); // Don't update wireguard timers until we are connected. - let Some(peer_socket) = self.peer_socket else { + let Some(peer_socket) = self.socket() else { return; }; @@ -1340,7 +1404,7 @@ where TunnResult::Done => {} TunnResult::Err(WireGuardError::ConnectionExpired) => { tracing::info!("Connection failed (wireguard tunnel expired)"); - self.is_failed = true; + self.state = ConnectionState::Failed; } TunnResult::Err(e) => { tracing::warn!(?e); @@ -1357,11 +1421,11 @@ where while let Some(event) = self.agent.poll_event() { match event { IceAgentEvent::DiscoveredRecv { source, .. } => { - self.possible_sockets.insert(source); + self.state.add_possible_socket(source); } IceAgentEvent::IceConnectionStateChange(IceConnectionState::Disconnected) => { tracing::info!("Connection failed (ICE timeout)"); - self.is_failed = true; + self.state = ConnectionState::Failed; } IceAgentEvent::NominatedSend { destination, @@ -1402,13 +1466,50 @@ where } }; - if self.peer_socket != Some(remote_socket) { - tracing::info!(old = ?self.peer_socket, new = ?remote_socket, duration_since_intent = ?self.duration_since_intent(now), "Updating remote socket"); - self.peer_socket = Some(remote_socket); + let old = match mem::replace(&mut self.state, ConnectionState::Failed) { + ConnectionState::Connecting { + possible_sockets, + buffered, + } => { + transmits.extend(buffered.into_iter().flat_map(|packet| { + make_owned_transmit(remote_socket, &packet, allocations, now) + })); + self.state = ConnectionState::Connected { + peer_socket: remote_socket, + possible_sockets, + }; - self.invalidate_candiates(allocations); - self.force_handshake(allocations, transmits, now); - } + None + } + ConnectionState::Connected { + peer_socket, + possible_sockets, + } if peer_socket == remote_socket => { + self.state = ConnectionState::Connected { + peer_socket, + possible_sockets, + }; + + continue; // If we re-nominate the same socket, don't just continue. TODO: Should this be fixed upstream? + } + ConnectionState::Connected { + peer_socket, + possible_sockets, + } => { + self.state = ConnectionState::Connected { + peer_socket: remote_socket, + possible_sockets, + }; + + Some(peer_socket) + } + ConnectionState::Failed => continue, // Failed connections are cleaned up, don't bother handling events. + }; + + tracing::info!(?old, new = ?remote_socket, duration_since_intent = ?self.duration_since_intent(now), "Updating remote socket"); + + self.invalidate_candiates(id, allocations, pending_events); + self.force_handshake(allocations, transmits, now); } IceAgentEvent::IceRestart(_) | IceAgentEvent::IceConnectionStateChange(_) => {} } @@ -1473,12 +1574,9 @@ where #[allow(clippy::too_many_arguments)] fn decapsulate<'b>( &mut self, - from: SocketAddr, - local: SocketAddr, packet: &[u8], - relayed: Option>, buffer: &'b mut [u8], - allocations: &mut HashMap>, + allocations: &mut HashMap, transmits: &mut VecDeque>, now: Instant, ) -> ControlFlow, MutableIpPacket<'b>> { @@ -1491,8 +1589,6 @@ where // In our API, we parse the packets directly as an IpPacket. // Thus, the caller can query whatever data they'd like, not just the source IP so we don't return it in addition. TunnResult::WriteToTunnelV4(packet, ip) => { - self.set_remote_from_wg_activity(local, from, relayed); - let ipv4_packet = MutableIpv4Packet::new(packet).expect("boringtun verifies validity"); debug_assert_eq!(ipv4_packet.get_source(), ip); @@ -1500,8 +1596,6 @@ where ControlFlow::Continue(ipv4_packet.into()) } TunnResult::WriteToTunnelV6(packet, ip) => { - self.set_remote_from_wg_activity(local, from, relayed); - let ipv6_packet = MutableIpv6Packet::new(packet).expect("boringtun verifies validity"); debug_assert_eq!(ipv6_packet.get_source(), ip); @@ -1514,14 +1608,38 @@ where // This should be fairly rare which is why we just allocate these and return them from `poll_transmit` instead. // Overall, this results in a much nicer API for our caller and should not affect performance. TunnResult::WriteToNetwork(bytes) => { - let socket = self.set_remote_from_wg_activity(local, from, relayed); + match &mut self.state { + ConnectionState::Connecting { buffered, .. } => { + tracing::debug!("No socket has been nominated yet, buffering WG packet"); - transmits.extend(make_owned_transmit(socket, bytes, allocations, now)); + buffered.push(bytes.to_owned()); - while let TunnResult::WriteToNetwork(packet) = - self.tunnel.decapsulate(None, &[], self.buffer.as_mut()) - { - transmits.extend(make_owned_transmit(socket, packet, allocations, now)); + while let TunnResult::WriteToNetwork(packet) = + self.tunnel.decapsulate(None, &[], self.buffer.as_mut()) + { + buffered.push(packet.to_owned()); + } + } + ConnectionState::Connected { peer_socket, .. } => { + transmits.extend(make_owned_transmit( + *peer_socket, + bytes, + allocations, + now, + )); + + while let TunnResult::WriteToNetwork(packet) = + self.tunnel.decapsulate(None, &[], self.buffer.as_mut()) + { + transmits.extend(make_owned_transmit( + *peer_socket, + packet, + allocations, + now, + )); + } + } + ConnectionState::Failed => {} } ControlFlow::Break(Ok(())) @@ -1531,7 +1649,7 @@ where fn force_handshake( &mut self, - allocations: &mut HashMap>, + allocations: &mut HashMap, transmits: &mut VecDeque>, now: Instant, ) where @@ -1545,14 +1663,14 @@ where let mut buf = [0u8; MAX_SCRATCH_SPACE]; let TunnResult::WriteToNetwork(bytes) = - self.tunnel.format_handshake_initiation(&mut buf, true) + self.tunnel.format_handshake_initiation(&mut buf, false) else { return; }; let socket = self - .peer_socket - .expect("cannot force handshake without socket"); + .socket() + .expect("cannot force handshake while not connected"); transmits.extend(make_owned_transmit(socket, bytes, allocations, now)); } @@ -1562,14 +1680,24 @@ where /// Each time we nominate a candidate pair, we don't really want to keep all the others active because it creates a lot of noise. /// At the same time, we want to retain trickle ICE and allow the ICE agent to find a _better_ pair, hence we invalidate by priority. #[tracing::instrument(level = "debug", skip_all, fields(nominated_prio))] - fn invalidate_candiates(&mut self, allocations: &HashMap>) { - let socket = match self.peer_socket { - Some(PeerSocket::Direct { source, .. }) => source, - Some(PeerSocket::Relay { relay, .. }) => match allocations.get(&relay) { - Some(alloc) => alloc.server(), + fn invalidate_candiates( + &mut self, + id: TId, + allocations: &HashMap, + pending_events: &mut VecDeque>, + ) where + TId: Copy + fmt::Display, + { + let Some(socket) = self.socket() else { + return; + }; + + let socket = match socket { + PeerSocket::Direct { source, .. } => source, + PeerSocket::Relay { relay, .. } => match allocations.get(&relay) { + Some(r) => r.server(), None => return, }, - None => return, }; let Some(nominated) = self.local_candidate(socket).cloned() else { @@ -1587,7 +1715,7 @@ where .collect::>(); for candidate in irrelevant_candidates { - self.agent.invalidate_candidate(&candidate); + remove_local_candidate(id, &mut self.agent, &candidate, pending_events) } } @@ -1597,13 +1725,24 @@ where .iter() .find(|c| c.addr() == source) } + + fn socket(&self) -> Option> { + match self.state { + ConnectionState::Connected { peer_socket, .. } => Some(peer_socket), + ConnectionState::Connecting { .. } | ConnectionState::Failed => None, + } + } + + fn is_failed(&self) -> bool { + matches!(self.state, ConnectionState::Failed) + } } #[must_use] fn make_owned_transmit( socket: PeerSocket, message: &[u8], - allocations: &mut HashMap>, + allocations: &mut HashMap, now: Instant, ) -> Option> where diff --git a/rust/connlib/snownet/src/ringbuffer.rs b/rust/connlib/snownet/src/ringbuffer.rs index 9369fb3cf..d9b420281 100644 --- a/rust/connlib/snownet/src/ringbuffer.rs +++ b/rust/connlib/snownet/src/ringbuffer.rs @@ -1,12 +1,14 @@ +use std::collections::VecDeque; + #[derive(Debug)] pub struct RingBuffer { - buffer: Vec, + buffer: VecDeque, } impl RingBuffer { pub fn new(capacity: usize) -> Self { RingBuffer { - buffer: Vec::with_capacity(capacity), + buffer: VecDeque::with_capacity(capacity), } } @@ -15,11 +17,11 @@ impl RingBuffer { // Remove the oldest element (at the beginning) if at capacity self.buffer.remove(0); } - self.buffer.push(item); + self.buffer.push_back(item); } pub fn pop(&mut self) -> Option { - self.buffer.pop() + self.buffer.pop_front() } pub fn clear(&mut self) { @@ -30,9 +32,13 @@ impl RingBuffer { self.buffer.iter() } + pub fn into_iter(self) -> impl Iterator { + self.buffer.into_iter() + } + #[cfg(test)] - fn inner(&self) -> &[T] { - self.buffer.as_slice() + fn inner(&self) -> (&[T], &[T]) { + self.buffer.as_slices() } } @@ -48,7 +54,7 @@ mod tests { buffer.push(2); buffer.push(3); - assert_eq!(buffer.inner(), &[1, 2, 3]); + assert_eq!(buffer.inner().0, &[1, 2, 3]); } #[test] @@ -59,6 +65,7 @@ mod tests { buffer.push(2); buffer.push(3); - assert_eq!(buffer.inner(), &[2, 3]); + assert_eq!(buffer.inner().0, &[2]); + assert_eq!(buffer.inner().1, &[3]); } } diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index d8305ee6c..3deb6039f 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -111,8 +111,13 @@ fn reconnect_discovers_new_interface() { progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); } + // To ensure that switching networks really works, block all traffic from the old IP. + let firewall = firewall + .with_block_rule(&alice, &bob) + .with_block_rule(&bob, &alice); + alice.switch_network("10.0.0.1:80"); - alice.node.reconnect(clock.now); + alice.span.in_scope(|| alice.node.reconnect(clock.now)); // Make some progress. for _ in 0..10 { @@ -239,7 +244,7 @@ fn only_generate_candidate_event_after_answer() { alice.accept_answer(1, bob.public_key(), answer, Instant::now()); assert!(iter::from_fn(|| alice.poll_event()).any(|ev| ev - == Event::SignalIceCandidate { + == Event::NewIceCandidate { connection: 1, candidate: Candidate::host(local_candidate, Protocol::Udp) .unwrap() @@ -609,6 +614,13 @@ impl EitherNode { } } + fn remove_remote_candidate(&mut self, id: u64, candidate: String) { + match self { + EitherNode::Client(n) => n.remove_remote_candidate(id, candidate), + EitherNode::Server(n) => n.remove_remote_candidate(id, candidate), + } + } + fn add_local_host_candidate(&mut self, socket: SocketAddr) { match self { EitherNode::Client(n) => n.add_local_host_candidate(socket).unwrap(), @@ -763,7 +775,7 @@ impl TestNode { fn signalled_candidates(&self) -> impl Iterator + '_ { self.events.iter().filter_map(|(e, instant)| match e { - Event::SignalIceCandidate { + Event::NewIceCandidate { connection, candidate, } => Some(( @@ -771,7 +783,9 @@ impl TestNode { Candidate::from_sdp_string(candidate).unwrap(), *instant, )), - Event::ConnectionEstablished(_) | Event::ConnectionFailed(_) => None, + Event::InvalidateIceCandidate { .. } + | Event::ConnectionEstablished(_) + | Event::ConnectionFailed(_) => None, }) } @@ -784,7 +798,8 @@ impl TestNode { fn failed_connections(&self) -> impl Iterator + '_ { self.events.iter().filter_map(|(e, instant)| match e { Event::ConnectionFailed(id) => Some((*id, *instant)), - Event::SignalIceCandidate { .. } => None, + Event::NewIceCandidate { .. } => None, + Event::InvalidateIceCandidate { .. } => None, Event::ConnectionEstablished(_) => None, }) } @@ -807,12 +822,18 @@ impl TestNode { self.events.push((v.clone(), now)); match v { - Event::SignalIceCandidate { + Event::NewIceCandidate { connection, candidate, } => other .span .in_scope(|| other.node.add_remote_candidate(connection, candidate, now)), + Event::InvalidateIceCandidate { + connection, + candidate, + } => other + .span + .in_scope(|| other.node.remove_remote_candidate(connection, candidate)), Event::ConnectionEstablished(_) => {} Event::ConnectionFailed(_) => {} }; diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 9882dcdda..2afab7826 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -182,6 +182,12 @@ where .add_remote_candidate(conn_id, ice_candidate, Instant::now()); } + pub fn remove_ice_candidate(&mut self, conn_id: GatewayId, ice_candidate: String) { + self.role_state + .node + .remove_remote_candidate(conn_id, ice_candidate); + } + pub fn create_or_reuse_connection( &mut self, resource_id: ResourceId, @@ -835,16 +841,25 @@ impl ClientState { snownet::Event::ConnectionFailed(id) => { self.cleanup_connected_gateway(&id); } - snownet::Event::SignalIceCandidate { + snownet::Event::NewIceCandidate { connection, candidate, } => self .buffered_events - .push_back(ClientEvent::SignalIceCandidate { + .push_back(ClientEvent::NewIceCandidate { conn_id: connection, candidate, }), - _ => {} + snownet::Event::InvalidateIceCandidate { + connection, + candidate, + } => self + .buffered_events + .push_back(ClientEvent::InvalidatedIceCandidate { + conn_id: connection, + candidate, + }), + snownet::Event::ConnectionEstablished { .. } => {} } } } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 5999fb9ef..615cd23c5 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -169,6 +169,12 @@ where .add_remote_candidate(conn_id, ice_candidate, Instant::now()); } + pub fn remove_ice_candidate(&mut self, conn_id: ClientId, ice_candidate: String) { + self.role_state + .node + .remove_remote_candidate(conn_id, ice_candidate); + } + fn new_peer( &mut self, ips: Vec, @@ -286,12 +292,22 @@ impl GatewayState { snownet::Event::ConnectionFailed(id) => { self.peers.remove(&id); } - snownet::Event::SignalIceCandidate { + snownet::Event::NewIceCandidate { connection, candidate, } => { self.buffered_events - .push_back(GatewayEvent::SignalIceCandidate { + .push_back(GatewayEvent::NewIceCandidate { + conn_id: connection, + candidate, + }); + } + snownet::Event::InvalidateIceCandidate { + connection, + candidate, + } => { + self.buffered_events + .push_back(GatewayEvent::InvalidIceCandidate { conn_id: connection, candidate, }); diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 1db6c8f27..7248471d1 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -242,7 +242,11 @@ where #[derive(Clone, Debug, PartialEq, Eq)] pub enum ClientEvent { - SignalIceCandidate { + NewIceCandidate { + conn_id: GatewayId, + candidate: String, + }, + InvalidatedIceCandidate { conn_id: GatewayId, candidate: String, }, @@ -256,7 +260,11 @@ pub enum ClientEvent { } pub enum GatewayEvent { - SignalIceCandidate { + NewIceCandidate { + conn_id: ClientId, + candidate: String, + }, + InvalidIceCandidate { conn_id: ClientId, candidate: String, }, diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 8ea794390..3083026d7 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -1,6 +1,6 @@ use crate::messages::{ - AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady, - EgressMessages, IngressMessages, RejectAccess, RequestConnection, + AllowAccess, ClientIceCandidates, ClientsIceCandidates, ConnectionReady, EgressMessages, + IngressMessages, RejectAccess, RequestConnection, }; use crate::CallbackHandler; use anyhow::Result; @@ -84,13 +84,25 @@ impl Eventloop { fn handle_tunnel_event(&mut self, event: firezone_tunnel::GatewayEvent) { match event { - firezone_tunnel::GatewayEvent::SignalIceCandidate { + firezone_tunnel::GatewayEvent::NewIceCandidate { conn_id: client, candidate, } => { self.portal.send( PHOENIX_TOPIC, - EgressMessages::BroadcastIceCandidates(BroadcastClientIceCandidates { + EgressMessages::BroadcastIceCandidates(ClientsIceCandidates { + client_ids: vec![client], + candidates: vec![candidate], + }), + ); + } + firezone_tunnel::GatewayEvent::InvalidIceCandidate { + conn_id: client, + candidate, + } => { + self.portal.send( + PHOENIX_TOPIC, + EgressMessages::BroadcastInvalidatedIceCandidates(ClientsIceCandidates { client_ids: vec![client], candidates: vec![candidate], }), @@ -140,6 +152,18 @@ impl Eventloop { self.tunnel.add_ice_candidate(client_id, candidate); } } + phoenix_channel::Event::InboundMessage { + msg: + IngressMessages::InvalidateIceCandidates(ClientIceCandidates { + client_id, + candidates, + }), + .. + } => { + for candidate in candidates { + self.tunnel.remove_ice_candidate(client_id, candidate); + } + } phoenix_channel::Event::InboundMessage { msg: IngressMessages::RejectAccess(RejectAccess { diff --git a/rust/gateway/src/messages.rs b/rust/gateway/src/messages.rs index d5c32518b..dbfe4fba2 100644 --- a/rust/gateway/src/messages.rs +++ b/rust/gateway/src/messages.rs @@ -72,12 +72,13 @@ pub enum IngressMessages { AllowAccess(AllowAccess), RejectAccess(RejectAccess), IceCandidates(ClientIceCandidates), + InvalidateIceCandidates(ClientIceCandidates), Init(InitGateway), } /// A client's ice candidate message. #[derive(Debug, Serialize, Clone, PartialEq, Eq)] -pub struct BroadcastClientIceCandidates { +pub struct ClientsIceCandidates { /// Client's id the ice candidates are meant for pub client_ids: Vec, /// Actual RTC ice candidates @@ -99,7 +100,8 @@ pub struct ClientIceCandidates { #[serde(rename_all = "snake_case", tag = "event", content = "payload")] pub enum EgressMessages { ConnectionReady(ConnectionReady), - BroadcastIceCandidates(BroadcastClientIceCandidates), + BroadcastIceCandidates(ClientsIceCandidates), + BroadcastInvalidatedIceCandidates(ClientsIceCandidates), } #[derive(Debug, Serialize, Clone)] @@ -170,6 +172,22 @@ mod test { let _: PhoenixMessage = serde_json::from_str(message).unwrap(); } + #[test] + fn invalidate_ice_candidates_message() { + let msg = r#"{"event":"invalidate_ice_candidates","ref":null,"topic":"gateway","payload":{"candidates":["candidate:7854631899965427361 1 udp 1694498559 172.28.0.100 47717 typ srflx"],"client_id":"2b1524e6-239e-4570-bc73-70a188e12101"}}"#; + let expected = IngressMessages::InvalidateIceCandidates(ClientIceCandidates { + client_id: "2b1524e6-239e-4570-bc73-70a188e12101".parse().unwrap(), + candidates: vec![ + "candidate:7854631899965427361 1 udp 1694498559 172.28.0.100 47717 typ srflx" + .to_owned(), + ], + }); + + let actual = serde_json::from_str::(msg).unwrap(); + + assert_eq!(actual, expected); + } + #[test] fn init_phoenix_message() { let m = InitMessage::Init(InitGateway { diff --git a/rust/snownet-tests/src/main.rs b/rust/snownet-tests/src/main.rs index bb6e22221..784f92949 100644 --- a/rust/snownet-tests/src/main.rs +++ b/rust/snownet-tests/src/main.rs @@ -383,7 +383,7 @@ impl Eventloop { } match self.pool.poll_event() { - Some(snownet::Event::SignalIceCandidate { + Some(snownet::Event::NewIceCandidate { connection, candidate, }) => { @@ -398,7 +398,7 @@ impl Eventloop { Some(snownet::Event::ConnectionFailed(conn)) => { return Poll::Ready(Ok(Event::ConnectionFailed { conn })) } - None => {} + Some(snownet::Event::InvalidateIceCandidate { .. }) | None => {} } if let Poll::Ready(Some(wire::Candidate { conn, candidate })) =