chore(snownet): don't update remote socket from WG activity (#4615)

Resolves: #4613.
This commit is contained in:
Thomas Eizinger
2024-04-20 10:15:19 +10:00
committed by GitHub
parent c8d36a8922
commit 0f7e80642d
14 changed files with 455 additions and 183 deletions

2
rust/Cargo.lock generated
View File

@@ -5809,7 +5809,7 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "str0m"
version = "0.5.0"
source = "git+https://github.com/firezone/str0m?branch=main#aeb62dfe53270d29d2cc72b03930a462e55b2e88"
source = "git+https://github.com/firezone/str0m?branch=main#1a69339a76ea21fa526d7a90893e3549e0281e0f"
dependencies = [
"combine",
"crc",

View File

@@ -1,7 +1,7 @@
use crate::{
messages::{
BroadcastGatewayIceCandidates, Connect, ConnectionDetails, EgressMessages,
GatewayIceCandidates, IngressMessages, InitClient, ReplyMessages,
Connect, ConnectionDetails, EgressMessages, GatewayIceCandidates, GatewaysIceCandidates,
IngressMessages, InitClient, ReplyMessages,
},
PHOENIX_TOPIC,
};
@@ -99,15 +99,29 @@ where
fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) {
match event {
firezone_tunnel::ClientEvent::SignalIceCandidate {
firezone_tunnel::ClientEvent::NewIceCandidate {
conn_id: gateway,
candidate,
} => {
tracing::debug!(%gateway, %candidate, "Sending ICE candidate to gateway");
tracing::debug!(%gateway, %candidate, "Sending new ICE candidate to gateway");
self.portal.send(
PHOENIX_TOPIC,
EgressMessages::BroadcastIceCandidates(BroadcastGatewayIceCandidates {
EgressMessages::BroadcastIceCandidates(GatewaysIceCandidates {
gateway_ids: vec![gateway],
candidates: vec![candidate],
}),
);
}
firezone_tunnel::ClientEvent::InvalidatedIceCandidate {
conn_id: gateway,
candidate,
} => {
tracing::debug!(%gateway, %candidate, "Sending invalidated ICE candidate to gateway");
self.portal.send(
PHOENIX_TOPIC,
EgressMessages::BroadcastInvalidatedIceCandidates(GatewaysIceCandidates {
gateway_ids: vec![gateway],
candidates: vec![candidate],
}),
@@ -200,6 +214,14 @@ where
IngressMessages::ResourceDeleted(resource) => {
self.tunnel.remove_resources(&[resource]);
}
IngressMessages::InvalidateIceCandidates(GatewayIceCandidates {
gateway_id,
candidates,
}) => {
for candidate in candidates {
self.tunnel.add_ice_candidate(gateway_id, candidate)
}
}
}
}

View File

@@ -47,20 +47,19 @@ pub enum IngressMessages {
ResourceDeleted(ResourceId),
IceCandidates(GatewayIceCandidates),
InvalidateIceCandidates(GatewayIceCandidates),
ConfigChanged(ConfigUpdate),
}
/// A gateway's ice candidate message.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct BroadcastGatewayIceCandidates {
/// Gateway's id the ice candidates are meant for
pub struct GatewaysIceCandidates {
/// The list of gateway IDs these candidates will be broadcast to.
pub gateway_ids: Vec<GatewayId>,
/// Actual RTC ice candidates
pub candidates: Vec<String>,
}
/// A gateway's ice candidate message.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct GatewayIceCandidates {
/// Gateway's id the ice candidates are from
@@ -89,7 +88,10 @@ pub enum EgressMessages {
},
RequestConnection(RequestConnection),
ReuseConnection(ReuseConnection),
BroadcastIceCandidates(BroadcastGatewayIceCandidates),
/// Candidates that can be used by the addressed gateways.
BroadcastIceCandidates(GatewaysIceCandidates),
/// Candidates that should no longer be used by the addressed gateways.
BroadcastInvalidatedIceCandidates(GatewaysIceCandidates),
}
#[cfg(test)]
@@ -108,7 +110,7 @@ mod test {
let message = r#"{"topic":"client","event":"broadcast_ice_candidates","payload":{"gateway_ids":["b3d34a15-55ab-40df-994b-a838e75d65d7"],"candidates":["candidate:7031633958891736544 1 udp 50331391 35.244.108.190 53909 typ relay"]},"ref":6}"#;
let expected = PhoenixMessage::new_message(
"client",
EgressMessages::BroadcastIceCandidates(BroadcastGatewayIceCandidates {
EgressMessages::BroadcastIceCandidates(GatewaysIceCandidates {
gateway_ids: vec!["b3d34a15-55ab-40df-994b-a838e75d65d7".parse().unwrap()],
candidates: vec![
"candidate:7031633958891736544 1 udp 50331391 35.244.108.190 53909 typ relay"
@@ -123,6 +125,22 @@ mod test {
assert_eq!(ingress_message, expected);
}
#[test]
fn invalidate_ice_candidates_message() {
let msg = r#"{"event":"invalidate_ice_candidates","ref":null,"topic":"client","payload":{"candidates":["candidate:7854631899965427361 1 udp 1694498559 172.28.0.100 47717 typ srflx"],"gateway_id":"2b1524e6-239e-4570-bc73-70a188e12101"}}"#;
let expected = IngressMessages::InvalidateIceCandidates(GatewayIceCandidates {
gateway_id: "2b1524e6-239e-4570-bc73-70a188e12101".parse().unwrap(),
candidates: vec![
"candidate:7854631899965427361 1 udp 1694498559 172.28.0.100 47717 typ srflx"
.to_owned(),
],
});
let actual = serde_json::from_str::<IngressMessages>(msg).unwrap();
assert_eq!(actual, expected);
}
#[test]
fn connection_ready_deserialization() {
let message = r#"{

View File

@@ -46,6 +46,14 @@ impl ResourceId {
#[derive(Hash, Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)]
pub struct ClientId(Uuid);
impl FromStr for ClientId {
type Err = uuid::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(ClientId(Uuid::parse_str(s)?))
}
}
impl FromStr for ResourceId {
type Err = uuid::Error;

View File

@@ -6,7 +6,6 @@ use crate::{
};
use ::backoff::backoff::Backoff;
use bytecodec::{DecodeExt as _, EncodeExt as _};
use core::fmt;
use rand::random;
use std::{
collections::{HashMap, VecDeque},
@@ -36,9 +35,7 @@ const REQUEST_TIMEOUT: Duration = Duration::from_secs(1);
///
/// Allocations have a lifetime and need to be continuously refreshed to stay active.
#[derive(Debug)]
pub struct Allocation<RId> {
id: RId,
pub struct Allocation {
server: SocketAddr,
/// If present, the last address the relay observed for us.
@@ -73,32 +70,19 @@ pub struct Allocation<RId> {
/// Note that any combination of IP versions is possible here.
/// We might have allocated an IPv6 address on a TURN server that we are talking to IPv4 and vice versa.
#[derive(Debug, Clone, Copy)]
pub struct Socket<RId> {
/// The ID of the relay.
id: RId,
pub struct Socket {
/// The address of the socket that was allocated.
address: SocketAddr,
}
impl<RId> Socket<RId>
where
RId: Copy,
{
pub fn id(&self) -> RId {
self.id
}
impl Socket {
pub fn address(&self) -> SocketAddr {
self.address
}
}
impl<RId> Allocation<RId>
where
RId: Copy + fmt::Debug,
{
impl Allocation {
pub fn new(
id: RId,
server: SocketAddr,
username: Username,
password: String,
@@ -106,7 +90,6 @@ where
now: Instant,
) -> Self {
let mut allocation = Self {
id,
server,
last_srflx_candidate: Default::default(),
ip4_allocation: Default::default(),
@@ -405,7 +388,7 @@ where
from: SocketAddr,
packet: &'p [u8],
now: Instant,
) -> Option<(SocketAddr, &'p [u8], Socket<RId>)> {
) -> Option<(SocketAddr, &'p [u8], Socket)> {
if from != self.server {
return None;
}
@@ -612,26 +595,20 @@ where
self.server
}
pub fn ip4_socket(&self) -> Option<Socket<RId>> {
pub fn ip4_socket(&self) -> Option<Socket> {
let address = self.ip4_allocation.as_ref().map(|c| c.addr())?;
debug_assert!(address.is_ipv4());
Some(Socket {
id: self.id,
address,
})
Some(Socket { address })
}
pub fn ip6_socket(&self) -> Option<Socket<RId>> {
pub fn ip6_socket(&self) -> Option<Socket> {
let address = self.ip6_allocation.as_ref().map(|c| c.addr())?;
debug_assert!(address.is_ipv6());
Some(Socket {
id: self.id,
address,
})
Some(Socket { address })
}
fn has_allocation(&self) -> bool {
@@ -1775,10 +1752,10 @@ mod tests {
let channel_bind_peer_2 = allocation.next_message().unwrap();
assert_eq!(channel_bind_peer_1.method(), CHANNEL_BIND);
assert_eq!(peer_address(&channel_bind_peer_1), PEER2_IP4);
assert_eq!(peer_address(&channel_bind_peer_1), PEER1);
assert_eq!(channel_bind_peer_2.method(), CHANNEL_BIND);
assert_eq!(peer_address(&channel_bind_peer_2), PEER1);
assert_eq!(peer_address(&channel_bind_peer_2), PEER2_IP4);
}
#[test]
@@ -2042,10 +2019,9 @@ mod tests {
message.get_attribute::<XorPeerAddress>().unwrap().address()
}
impl Allocation<u64> {
impl Allocation {
fn for_test(start: Instant) -> Self {
Allocation::new(
1,
RELAY,
Username::new("foobar".to_owned()).unwrap(),
"baz".to_owned(),

View File

@@ -1,5 +1,6 @@
use crate::allocation::{Allocation, Socket};
use crate::index::IndexLfsr;
use crate::ringbuffer::RingBuffer;
use crate::stats::{ConnectionStats, NodeStats};
use crate::stun_binding::StunBinding;
use crate::utils::earliest;
@@ -16,6 +17,7 @@ use secrecy::{ExposeSecret, Secret};
use std::borrow::Cow;
use std::hash::Hash;
use std::marker::PhantomData;
use std::mem;
use std::ops::ControlFlow;
use std::time::{Duration, Instant};
use std::{
@@ -87,7 +89,7 @@ pub struct Node<T, TId, RId> {
next_rate_limiter_reset: Option<Instant>,
bindings: HashMap<SocketAddr, StunBinding>,
allocations: HashMap<RId, Allocation<RId>>,
allocations: HashMap<RId, Allocation>,
connections: Connections<TId, RId>,
pending_events: VecDeque<Event<TId>>,
@@ -232,12 +234,27 @@ where
}
}
#[tracing::instrument(level = "info", skip_all, fields(%id))]
pub fn remove_remote_candidate(&mut self, id: TId, candidate: String) {
let candidate = match Candidate::from_sdp_string(&candidate) {
Ok(c) => c,
Err(e) => {
tracing::debug!("Failed to parse candidate: {e}");
return;
}
};
if let Some(agent) = self.connections.agent_mut(id) {
agent.invalidate_candidate(&candidate);
}
}
/// Attempts to find the [`Allocation`] on the same relay as the remote's candidate.
///
/// To do that, we need to check all candidates of each allocation and compare their IP.
/// The same relay might be reachable over IPv4 and IPv6.
#[must_use]
fn same_relay_as_peer(&mut self, candidate: &Candidate) -> Option<&mut Allocation<RId>> {
fn same_relay_as_peer(&mut self, candidate: &Candidate) -> Option<&mut Allocation> {
self.allocations.iter_mut().find_map(|(_, allocation)| {
allocation
.current_candidates()
@@ -283,12 +300,11 @@ where
ControlFlow::Break(Err(e)) => return Err(e),
};
let (id, packet) =
match self.connections_try_handle(from, local, packet, relayed, buffer, now) {
ControlFlow::Continue(c) => c,
ControlFlow::Break(Ok(())) => return Ok(None),
ControlFlow::Break(Err(e)) => return Err(e),
};
let (id, packet) = match self.connections_try_handle(from, packet, buffer, now) {
ControlFlow::Continue(c) => c,
ControlFlow::Break(Ok(())) => return Ok(None),
ControlFlow::Break(Err(e)) => return Err(e),
};
Ok(Some((id, packet)))
}
@@ -311,7 +327,7 @@ where
.ok_or(Error::NotConnected)?;
// Must bail early if we don't have a socket yet to avoid running into WG timeouts.
let socket = conn.peer_socket.ok_or(Error::NotConnected)?;
let socket = conn.socket().ok_or(Error::NotConnected)?;
let (header, payload) = self.buffer.as_mut().split_at_mut(4);
@@ -400,7 +416,13 @@ where
self.bindings_and_allocations_drain_events();
for (id, connection) in self.connections.iter_established_mut() {
connection.handle_timeout(id, now, &mut self.allocations, &mut self.buffered_transmits);
connection.handle_timeout(
id,
now,
&mut self.allocations,
&mut self.buffered_transmits,
&mut self.pending_events,
);
}
for (id, connection) in self.connections.initial.iter_mut() {
@@ -469,7 +491,7 @@ where
self.allocations.insert(
*id,
Allocation::new(*id, *server, username, password.clone(), realm, now),
Allocation::new(*server, username, password.clone(), realm, now),
);
tracing::info!(address = %server, "Added new TURN server");
@@ -504,14 +526,15 @@ where
Some(self.rate_limiter.clone()),
),
next_timer_update: now,
peer_socket: None,
possible_sockets: Default::default(),
stats: Default::default(),
buffer: Box::new([0u8; MAX_UDP_SIZE]),
intent_sent_at,
is_failed: false,
signalling_completed_at: now,
remote_pub_key: remote,
state: ConnectionState::Connecting {
possible_sockets: HashSet::default(),
buffered: RingBuffer::new(10),
},
}
}
@@ -577,7 +600,7 @@ where
local: SocketAddr,
packet: &'p [u8],
now: Instant,
) -> ControlFlow<(), (SocketAddr, &'p [u8], Option<Socket<RId>>)> {
) -> ControlFlow<(), (SocketAddr, &'p [u8], Option<Socket>)> {
match packet.first().copied() {
// STUN method range
Some(0..=3) => {
@@ -658,26 +681,21 @@ where
fn connections_try_handle<'b>(
&mut self,
from: SocketAddr,
local: SocketAddr,
packet: &[u8],
relayed: Option<Socket<RId>>,
buffer: &'b mut [u8],
now: Instant,
) -> ControlFlow<Result<(), Error>, (TId, MutableIpPacket<'b>)> {
for (id, conn) in self.connections.iter_established_mut() {
let _span = info_span!("connection", %id).entered();
if !conn.accepts(from) {
if !conn.accepts(&from) {
continue;
}
let handshake_complete_before_decapsulate = conn.wg_handshake_complete();
let control_flow = conn.decapsulate(
from,
local,
packet,
relayed,
buffer,
&mut self.allocations,
&mut self.buffered_transmits,
@@ -727,7 +745,8 @@ where
CandidateEvent::Invalid(candidate) => {
for (id, agent) in self.connections.agents_mut() {
let _span = info_span!("connection", %id).entered();
agent.invalidate_candidate(&candidate);
remove_local_candidate(id, agent, &candidate, &mut self.pending_events);
}
}
}
@@ -965,6 +984,7 @@ impl<TId, RId> Default for Connections<TId, RId> {
impl<TId, RId> Connections<TId, RId>
where
TId: Eq + Hash + Copy + fmt::Display,
RId: Copy + Eq + Hash + PartialEq + fmt::Debug + fmt::Display,
{
fn remove_failed(&mut self, events: &mut VecDeque<Event<TId>>) {
self.initial.retain(|id, conn| {
@@ -977,7 +997,7 @@ where
});
self.established.retain(|id, conn| {
if conn.is_failed {
if conn.is_failed() {
events.push_back(Event::ConnectionFailed(*id));
return false;
}
@@ -1033,7 +1053,7 @@ fn encode_as_channel_data<RId>(
relay: RId,
dest: SocketAddr,
contents: &[u8],
allocations: &mut HashMap<RId, Allocation<RId>>,
allocations: &mut HashMap<RId, Allocation>,
now: Instant,
) -> Result<Transmit<'static>, EncodeError>
where
@@ -1093,7 +1113,25 @@ fn add_local_candidate<TId>(
let is_new = agent.add_local_candidate(candidate.clone());
if is_new {
pending_events.push_back(Event::SignalIceCandidate {
pending_events.push_back(Event::NewIceCandidate {
connection: id,
candidate: candidate.to_sdp_string(),
})
}
}
fn remove_local_candidate<TId>(
id: TId,
agent: &mut IceAgent,
candidate: &Candidate,
pending_events: &mut VecDeque<Event<TId>>,
) where
TId: fmt::Display,
{
let was_present = agent.invalidate_candidate(candidate);
if was_present {
pending_events.push_back(Event::InvalidateIceCandidate {
connection: id,
candidate: candidate.to_sdp_string(),
})
@@ -1119,13 +1157,22 @@ pub struct Credentials {
#[derive(Debug, PartialEq, Clone)]
pub enum Event<TId> {
/// Signal the ICE candidate to the remote via the signalling channel.
/// We created a new candidate for this connection and ask to signal it to the remote party.
///
/// Candidates are in SDP format although this may change and should be considered an implementation detail of the application.
SignalIceCandidate {
NewIceCandidate {
connection: TId,
candidate: String,
},
/// We invalidated a candidate for this connection and ask to signal that to the remote party.
///
/// Candidates are in SDP format although this may change and should be considered an implementation detail of the application.
InvalidateIceCandidate {
connection: TId,
candidate: String,
},
ConnectionEstablished(TId),
/// We failed to establish a connection.
@@ -1195,24 +1242,55 @@ impl InitialConnection {
struct Connection<RId> {
agent: IceAgent,
remote_pub_key: PublicKey,
tunnel: Tunn,
remote_pub_key: PublicKey,
next_timer_update: Instant,
// When this is `Some`, we are connected.
peer_socket: Option<PeerSocket<RId>>,
// Socket addresses from which we might receive data (even before we are connected).
possible_sockets: HashSet<SocketAddr>,
state: ConnectionState<RId>,
stats: ConnectionStats,
intent_sent_at: Instant,
signalling_completed_at: Instant,
buffer: Box<[u8; MAX_UDP_SIZE]>,
intent_sent_at: Instant,
}
is_failed: bool,
enum ConnectionState<RId> {
/// We are still running ICE to figure out, which socket to use to send data.
Connecting {
/// Socket addresses from which we might receive data (even before we are connected).
possible_sockets: HashSet<SocketAddr>,
/// Packets emitted by wireguard whilst are still running ICE.
///
/// This can happen if the remote's WG session initiation arrives at our socket before we nominate it.
/// A session initiation requires a response that we must not drop, otherwise the connection setup experiences unnecessary delays.
buffered: RingBuffer<Vec<u8>>,
},
/// A socket has been nominated.
Connected {
/// Our nominated socket.
peer_socket: PeerSocket<RId>,
/// Other addresses that we might see traffic from (e.g. STUN messages during roaming).
possible_sockets: HashSet<SocketAddr>,
},
/// The connection failed in an unrecoverable way and will be GC'd.
Failed,
}
signalling_completed_at: Instant,
impl<RId> ConnectionState<RId> {
fn add_possible_socket(&mut self, socket: SocketAddr) {
let possible_sockets = match self {
ConnectionState::Connecting {
possible_sockets, ..
} => possible_sockets,
ConnectionState::Connected {
possible_sockets, ..
} => possible_sockets,
ConnectionState::Failed => return,
};
possible_sockets.insert(socket);
}
}
/// The socket of the peer we are connected to.
@@ -1237,14 +1315,24 @@ where
/// Whilst we establish connections, we may see traffic from a certain address, prior to the negotiation being fully complete.
/// We already want to accept that traffic and not throw it away.
#[must_use]
fn accepts(&self, addr: SocketAddr) -> bool {
let from_connected_remote = self.peer_socket.is_some_and(|r| match r {
PeerSocket::Direct { dest, .. } => dest == addr,
PeerSocket::Relay { dest, .. } => dest == addr,
});
let from_possible_remote = self.possible_sockets.contains(&addr);
fn accepts(&self, addr: &SocketAddr) -> bool {
match &self.state {
ConnectionState::Connecting {
possible_sockets, ..
} => possible_sockets.contains(addr),
ConnectionState::Connected {
peer_socket,
possible_sockets,
} => {
let from_nominated = match peer_socket {
PeerSocket::Direct { dest, .. } => dest == addr,
PeerSocket::Relay { dest, .. } => dest == addr,
};
from_connected_remote || from_possible_remote
from_nominated || possible_sockets.contains(addr)
}
ConnectionState::Failed => false,
}
}
fn wg_handshake_complete(&self) -> bool {
@@ -1255,31 +1343,6 @@ where
now.duration_since(self.intent_sent_at)
}
fn set_remote_from_wg_activity(
&mut self,
local: SocketAddr,
dest: SocketAddr,
relay_socket: Option<Socket<RId>>,
) -> PeerSocket<RId> {
let remote_socket = match relay_socket {
Some(relay_socket) => PeerSocket::Relay {
relay: relay_socket.id(),
dest,
},
None => PeerSocket::Direct {
source: local,
dest,
},
};
if self.peer_socket != Some(remote_socket) {
tracing::debug!(old = ?self.peer_socket, new = ?remote_socket, "Updating remote socket from WG activity");
self.peer_socket = Some(remote_socket);
}
remote_socket
}
#[must_use]
fn poll_timeout(&mut self) -> Option<Instant> {
let agent_timeout = self.agent.poll_timeout();
@@ -1302,8 +1365,9 @@ where
&mut self,
id: TId,
now: Instant,
allocations: &mut HashMap<RId, Allocation<RId>>,
allocations: &mut HashMap<RId, Allocation>,
transmits: &mut VecDeque<Transmit<'static>>,
pending_events: &mut VecDeque<Event<TId>>,
) where
TId: fmt::Display + Copy,
RId: Copy + fmt::Display,
@@ -1315,7 +1379,7 @@ where
.is_some_and(|timeout| now >= timeout)
{
tracing::info!("Connection failed (no candidates received)");
self.is_failed = true;
self.state = ConnectionState::Failed;
return;
}
@@ -1325,7 +1389,7 @@ where
self.next_timer_update = now + Duration::from_secs(1);
// Don't update wireguard timers until we are connected.
let Some(peer_socket) = self.peer_socket else {
let Some(peer_socket) = self.socket() else {
return;
};
@@ -1340,7 +1404,7 @@ where
TunnResult::Done => {}
TunnResult::Err(WireGuardError::ConnectionExpired) => {
tracing::info!("Connection failed (wireguard tunnel expired)");
self.is_failed = true;
self.state = ConnectionState::Failed;
}
TunnResult::Err(e) => {
tracing::warn!(?e);
@@ -1357,11 +1421,11 @@ where
while let Some(event) = self.agent.poll_event() {
match event {
IceAgentEvent::DiscoveredRecv { source, .. } => {
self.possible_sockets.insert(source);
self.state.add_possible_socket(source);
}
IceAgentEvent::IceConnectionStateChange(IceConnectionState::Disconnected) => {
tracing::info!("Connection failed (ICE timeout)");
self.is_failed = true;
self.state = ConnectionState::Failed;
}
IceAgentEvent::NominatedSend {
destination,
@@ -1402,13 +1466,50 @@ where
}
};
if self.peer_socket != Some(remote_socket) {
tracing::info!(old = ?self.peer_socket, new = ?remote_socket, duration_since_intent = ?self.duration_since_intent(now), "Updating remote socket");
self.peer_socket = Some(remote_socket);
let old = match mem::replace(&mut self.state, ConnectionState::Failed) {
ConnectionState::Connecting {
possible_sockets,
buffered,
} => {
transmits.extend(buffered.into_iter().flat_map(|packet| {
make_owned_transmit(remote_socket, &packet, allocations, now)
}));
self.state = ConnectionState::Connected {
peer_socket: remote_socket,
possible_sockets,
};
self.invalidate_candiates(allocations);
self.force_handshake(allocations, transmits, now);
}
None
}
ConnectionState::Connected {
peer_socket,
possible_sockets,
} if peer_socket == remote_socket => {
self.state = ConnectionState::Connected {
peer_socket,
possible_sockets,
};
continue; // If we re-nominate the same socket, don't just continue. TODO: Should this be fixed upstream?
}
ConnectionState::Connected {
peer_socket,
possible_sockets,
} => {
self.state = ConnectionState::Connected {
peer_socket: remote_socket,
possible_sockets,
};
Some(peer_socket)
}
ConnectionState::Failed => continue, // Failed connections are cleaned up, don't bother handling events.
};
tracing::info!(?old, new = ?remote_socket, duration_since_intent = ?self.duration_since_intent(now), "Updating remote socket");
self.invalidate_candiates(id, allocations, pending_events);
self.force_handshake(allocations, transmits, now);
}
IceAgentEvent::IceRestart(_) | IceAgentEvent::IceConnectionStateChange(_) => {}
}
@@ -1473,12 +1574,9 @@ where
#[allow(clippy::too_many_arguments)]
fn decapsulate<'b>(
&mut self,
from: SocketAddr,
local: SocketAddr,
packet: &[u8],
relayed: Option<Socket<RId>>,
buffer: &'b mut [u8],
allocations: &mut HashMap<RId, Allocation<RId>>,
allocations: &mut HashMap<RId, Allocation>,
transmits: &mut VecDeque<Transmit<'static>>,
now: Instant,
) -> ControlFlow<Result<(), Error>, MutableIpPacket<'b>> {
@@ -1491,8 +1589,6 @@ where
// In our API, we parse the packets directly as an IpPacket.
// Thus, the caller can query whatever data they'd like, not just the source IP so we don't return it in addition.
TunnResult::WriteToTunnelV4(packet, ip) => {
self.set_remote_from_wg_activity(local, from, relayed);
let ipv4_packet =
MutableIpv4Packet::new(packet).expect("boringtun verifies validity");
debug_assert_eq!(ipv4_packet.get_source(), ip);
@@ -1500,8 +1596,6 @@ where
ControlFlow::Continue(ipv4_packet.into())
}
TunnResult::WriteToTunnelV6(packet, ip) => {
self.set_remote_from_wg_activity(local, from, relayed);
let ipv6_packet =
MutableIpv6Packet::new(packet).expect("boringtun verifies validity");
debug_assert_eq!(ipv6_packet.get_source(), ip);
@@ -1514,14 +1608,38 @@ where
// This should be fairly rare which is why we just allocate these and return them from `poll_transmit` instead.
// Overall, this results in a much nicer API for our caller and should not affect performance.
TunnResult::WriteToNetwork(bytes) => {
let socket = self.set_remote_from_wg_activity(local, from, relayed);
match &mut self.state {
ConnectionState::Connecting { buffered, .. } => {
tracing::debug!("No socket has been nominated yet, buffering WG packet");
transmits.extend(make_owned_transmit(socket, bytes, allocations, now));
buffered.push(bytes.to_owned());
while let TunnResult::WriteToNetwork(packet) =
self.tunnel.decapsulate(None, &[], self.buffer.as_mut())
{
transmits.extend(make_owned_transmit(socket, packet, allocations, now));
while let TunnResult::WriteToNetwork(packet) =
self.tunnel.decapsulate(None, &[], self.buffer.as_mut())
{
buffered.push(packet.to_owned());
}
}
ConnectionState::Connected { peer_socket, .. } => {
transmits.extend(make_owned_transmit(
*peer_socket,
bytes,
allocations,
now,
));
while let TunnResult::WriteToNetwork(packet) =
self.tunnel.decapsulate(None, &[], self.buffer.as_mut())
{
transmits.extend(make_owned_transmit(
*peer_socket,
packet,
allocations,
now,
));
}
}
ConnectionState::Failed => {}
}
ControlFlow::Break(Ok(()))
@@ -1531,7 +1649,7 @@ where
fn force_handshake(
&mut self,
allocations: &mut HashMap<RId, Allocation<RId>>,
allocations: &mut HashMap<RId, Allocation>,
transmits: &mut VecDeque<Transmit<'static>>,
now: Instant,
) where
@@ -1545,14 +1663,14 @@ where
let mut buf = [0u8; MAX_SCRATCH_SPACE];
let TunnResult::WriteToNetwork(bytes) =
self.tunnel.format_handshake_initiation(&mut buf, true)
self.tunnel.format_handshake_initiation(&mut buf, false)
else {
return;
};
let socket = self
.peer_socket
.expect("cannot force handshake without socket");
.socket()
.expect("cannot force handshake while not connected");
transmits.extend(make_owned_transmit(socket, bytes, allocations, now));
}
@@ -1562,14 +1680,24 @@ where
/// Each time we nominate a candidate pair, we don't really want to keep all the others active because it creates a lot of noise.
/// At the same time, we want to retain trickle ICE and allow the ICE agent to find a _better_ pair, hence we invalidate by priority.
#[tracing::instrument(level = "debug", skip_all, fields(nominated_prio))]
fn invalidate_candiates(&mut self, allocations: &HashMap<RId, Allocation<RId>>) {
let socket = match self.peer_socket {
Some(PeerSocket::Direct { source, .. }) => source,
Some(PeerSocket::Relay { relay, .. }) => match allocations.get(&relay) {
Some(alloc) => alloc.server(),
fn invalidate_candiates<TId>(
&mut self,
id: TId,
allocations: &HashMap<RId, Allocation>,
pending_events: &mut VecDeque<Event<TId>>,
) where
TId: Copy + fmt::Display,
{
let Some(socket) = self.socket() else {
return;
};
let socket = match socket {
PeerSocket::Direct { source, .. } => source,
PeerSocket::Relay { relay, .. } => match allocations.get(&relay) {
Some(r) => r.server(),
None => return,
},
None => return,
};
let Some(nominated) = self.local_candidate(socket).cloned() else {
@@ -1587,7 +1715,7 @@ where
.collect::<Vec<_>>();
for candidate in irrelevant_candidates {
self.agent.invalidate_candidate(&candidate);
remove_local_candidate(id, &mut self.agent, &candidate, pending_events)
}
}
@@ -1597,13 +1725,24 @@ where
.iter()
.find(|c| c.addr() == source)
}
fn socket(&self) -> Option<PeerSocket<RId>> {
match self.state {
ConnectionState::Connected { peer_socket, .. } => Some(peer_socket),
ConnectionState::Connecting { .. } | ConnectionState::Failed => None,
}
}
fn is_failed(&self) -> bool {
matches!(self.state, ConnectionState::Failed)
}
}
#[must_use]
fn make_owned_transmit<RId>(
socket: PeerSocket<RId>,
message: &[u8],
allocations: &mut HashMap<RId, Allocation<RId>>,
allocations: &mut HashMap<RId, Allocation>,
now: Instant,
) -> Option<Transmit<'static>>
where

View File

@@ -1,12 +1,14 @@
use std::collections::VecDeque;
#[derive(Debug)]
pub struct RingBuffer<T> {
buffer: Vec<T>,
buffer: VecDeque<T>,
}
impl<T: PartialEq> RingBuffer<T> {
pub fn new(capacity: usize) -> Self {
RingBuffer {
buffer: Vec::with_capacity(capacity),
buffer: VecDeque::with_capacity(capacity),
}
}
@@ -15,11 +17,11 @@ impl<T: PartialEq> RingBuffer<T> {
// Remove the oldest element (at the beginning) if at capacity
self.buffer.remove(0);
}
self.buffer.push(item);
self.buffer.push_back(item);
}
pub fn pop(&mut self) -> Option<T> {
self.buffer.pop()
self.buffer.pop_front()
}
pub fn clear(&mut self) {
@@ -30,9 +32,13 @@ impl<T: PartialEq> RingBuffer<T> {
self.buffer.iter()
}
pub fn into_iter(self) -> impl Iterator<Item = T> {
self.buffer.into_iter()
}
#[cfg(test)]
fn inner(&self) -> &[T] {
self.buffer.as_slice()
fn inner(&self) -> (&[T], &[T]) {
self.buffer.as_slices()
}
}
@@ -48,7 +54,7 @@ mod tests {
buffer.push(2);
buffer.push(3);
assert_eq!(buffer.inner(), &[1, 2, 3]);
assert_eq!(buffer.inner().0, &[1, 2, 3]);
}
#[test]
@@ -59,6 +65,7 @@ mod tests {
buffer.push(2);
buffer.push(3);
assert_eq!(buffer.inner(), &[2, 3]);
assert_eq!(buffer.inner().0, &[2]);
assert_eq!(buffer.inner().1, &[3]);
}
}

View File

@@ -111,8 +111,13 @@ fn reconnect_discovers_new_interface() {
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
}
// To ensure that switching networks really works, block all traffic from the old IP.
let firewall = firewall
.with_block_rule(&alice, &bob)
.with_block_rule(&bob, &alice);
alice.switch_network("10.0.0.1:80");
alice.node.reconnect(clock.now);
alice.span.in_scope(|| alice.node.reconnect(clock.now));
// Make some progress.
for _ in 0..10 {
@@ -239,7 +244,7 @@ fn only_generate_candidate_event_after_answer() {
alice.accept_answer(1, bob.public_key(), answer, Instant::now());
assert!(iter::from_fn(|| alice.poll_event()).any(|ev| ev
== Event::SignalIceCandidate {
== Event::NewIceCandidate {
connection: 1,
candidate: Candidate::host(local_candidate, Protocol::Udp)
.unwrap()
@@ -609,6 +614,13 @@ impl EitherNode {
}
}
fn remove_remote_candidate(&mut self, id: u64, candidate: String) {
match self {
EitherNode::Client(n) => n.remove_remote_candidate(id, candidate),
EitherNode::Server(n) => n.remove_remote_candidate(id, candidate),
}
}
fn add_local_host_candidate(&mut self, socket: SocketAddr) {
match self {
EitherNode::Client(n) => n.add_local_host_candidate(socket).unwrap(),
@@ -763,7 +775,7 @@ impl TestNode {
fn signalled_candidates(&self) -> impl Iterator<Item = (u64, Candidate, Instant)> + '_ {
self.events.iter().filter_map(|(e, instant)| match e {
Event::SignalIceCandidate {
Event::NewIceCandidate {
connection,
candidate,
} => Some((
@@ -771,7 +783,9 @@ impl TestNode {
Candidate::from_sdp_string(candidate).unwrap(),
*instant,
)),
Event::ConnectionEstablished(_) | Event::ConnectionFailed(_) => None,
Event::InvalidateIceCandidate { .. }
| Event::ConnectionEstablished(_)
| Event::ConnectionFailed(_) => None,
})
}
@@ -784,7 +798,8 @@ impl TestNode {
fn failed_connections(&self) -> impl Iterator<Item = (u64, Instant)> + '_ {
self.events.iter().filter_map(|(e, instant)| match e {
Event::ConnectionFailed(id) => Some((*id, *instant)),
Event::SignalIceCandidate { .. } => None,
Event::NewIceCandidate { .. } => None,
Event::InvalidateIceCandidate { .. } => None,
Event::ConnectionEstablished(_) => None,
})
}
@@ -807,12 +822,18 @@ impl TestNode {
self.events.push((v.clone(), now));
match v {
Event::SignalIceCandidate {
Event::NewIceCandidate {
connection,
candidate,
} => other
.span
.in_scope(|| other.node.add_remote_candidate(connection, candidate, now)),
Event::InvalidateIceCandidate {
connection,
candidate,
} => other
.span
.in_scope(|| other.node.remove_remote_candidate(connection, candidate)),
Event::ConnectionEstablished(_) => {}
Event::ConnectionFailed(_) => {}
};

View File

@@ -182,6 +182,12 @@ where
.add_remote_candidate(conn_id, ice_candidate, Instant::now());
}
pub fn remove_ice_candidate(&mut self, conn_id: GatewayId, ice_candidate: String) {
self.role_state
.node
.remove_remote_candidate(conn_id, ice_candidate);
}
pub fn create_or_reuse_connection(
&mut self,
resource_id: ResourceId,
@@ -835,16 +841,25 @@ impl ClientState {
snownet::Event::ConnectionFailed(id) => {
self.cleanup_connected_gateway(&id);
}
snownet::Event::SignalIceCandidate {
snownet::Event::NewIceCandidate {
connection,
candidate,
} => self
.buffered_events
.push_back(ClientEvent::SignalIceCandidate {
.push_back(ClientEvent::NewIceCandidate {
conn_id: connection,
candidate,
}),
_ => {}
snownet::Event::InvalidateIceCandidate {
connection,
candidate,
} => self
.buffered_events
.push_back(ClientEvent::InvalidatedIceCandidate {
conn_id: connection,
candidate,
}),
snownet::Event::ConnectionEstablished { .. } => {}
}
}
}

View File

@@ -169,6 +169,12 @@ where
.add_remote_candidate(conn_id, ice_candidate, Instant::now());
}
pub fn remove_ice_candidate(&mut self, conn_id: ClientId, ice_candidate: String) {
self.role_state
.node
.remove_remote_candidate(conn_id, ice_candidate);
}
fn new_peer(
&mut self,
ips: Vec<IpNetwork>,
@@ -286,12 +292,22 @@ impl GatewayState {
snownet::Event::ConnectionFailed(id) => {
self.peers.remove(&id);
}
snownet::Event::SignalIceCandidate {
snownet::Event::NewIceCandidate {
connection,
candidate,
} => {
self.buffered_events
.push_back(GatewayEvent::SignalIceCandidate {
.push_back(GatewayEvent::NewIceCandidate {
conn_id: connection,
candidate,
});
}
snownet::Event::InvalidateIceCandidate {
connection,
candidate,
} => {
self.buffered_events
.push_back(GatewayEvent::InvalidIceCandidate {
conn_id: connection,
candidate,
});

View File

@@ -242,7 +242,11 @@ where
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ClientEvent {
SignalIceCandidate {
NewIceCandidate {
conn_id: GatewayId,
candidate: String,
},
InvalidatedIceCandidate {
conn_id: GatewayId,
candidate: String,
},
@@ -256,7 +260,11 @@ pub enum ClientEvent {
}
pub enum GatewayEvent {
SignalIceCandidate {
NewIceCandidate {
conn_id: ClientId,
candidate: String,
},
InvalidIceCandidate {
conn_id: ClientId,
candidate: String,
},

View File

@@ -1,6 +1,6 @@
use crate::messages::{
AllowAccess, BroadcastClientIceCandidates, ClientIceCandidates, ConnectionReady,
EgressMessages, IngressMessages, RejectAccess, RequestConnection,
AllowAccess, ClientIceCandidates, ClientsIceCandidates, ConnectionReady, EgressMessages,
IngressMessages, RejectAccess, RequestConnection,
};
use crate::CallbackHandler;
use anyhow::Result;
@@ -84,13 +84,25 @@ impl Eventloop {
fn handle_tunnel_event(&mut self, event: firezone_tunnel::GatewayEvent) {
match event {
firezone_tunnel::GatewayEvent::SignalIceCandidate {
firezone_tunnel::GatewayEvent::NewIceCandidate {
conn_id: client,
candidate,
} => {
self.portal.send(
PHOENIX_TOPIC,
EgressMessages::BroadcastIceCandidates(BroadcastClientIceCandidates {
EgressMessages::BroadcastIceCandidates(ClientsIceCandidates {
client_ids: vec![client],
candidates: vec![candidate],
}),
);
}
firezone_tunnel::GatewayEvent::InvalidIceCandidate {
conn_id: client,
candidate,
} => {
self.portal.send(
PHOENIX_TOPIC,
EgressMessages::BroadcastInvalidatedIceCandidates(ClientsIceCandidates {
client_ids: vec![client],
candidates: vec![candidate],
}),
@@ -140,6 +152,18 @@ impl Eventloop {
self.tunnel.add_ice_candidate(client_id, candidate);
}
}
phoenix_channel::Event::InboundMessage {
msg:
IngressMessages::InvalidateIceCandidates(ClientIceCandidates {
client_id,
candidates,
}),
..
} => {
for candidate in candidates {
self.tunnel.remove_ice_candidate(client_id, candidate);
}
}
phoenix_channel::Event::InboundMessage {
msg:
IngressMessages::RejectAccess(RejectAccess {

View File

@@ -72,12 +72,13 @@ pub enum IngressMessages {
AllowAccess(AllowAccess),
RejectAccess(RejectAccess),
IceCandidates(ClientIceCandidates),
InvalidateIceCandidates(ClientIceCandidates),
Init(InitGateway),
}
/// A client's ice candidate message.
#[derive(Debug, Serialize, Clone, PartialEq, Eq)]
pub struct BroadcastClientIceCandidates {
pub struct ClientsIceCandidates {
/// Client's id the ice candidates are meant for
pub client_ids: Vec<ClientId>,
/// Actual RTC ice candidates
@@ -99,7 +100,8 @@ pub struct ClientIceCandidates {
#[serde(rename_all = "snake_case", tag = "event", content = "payload")]
pub enum EgressMessages {
ConnectionReady(ConnectionReady),
BroadcastIceCandidates(BroadcastClientIceCandidates),
BroadcastIceCandidates(ClientsIceCandidates),
BroadcastInvalidatedIceCandidates(ClientsIceCandidates),
}
#[derive(Debug, Serialize, Clone)]
@@ -170,6 +172,22 @@ mod test {
let _: PhoenixMessage<IngressMessages, ()> = serde_json::from_str(message).unwrap();
}
#[test]
fn invalidate_ice_candidates_message() {
let msg = r#"{"event":"invalidate_ice_candidates","ref":null,"topic":"gateway","payload":{"candidates":["candidate:7854631899965427361 1 udp 1694498559 172.28.0.100 47717 typ srflx"],"client_id":"2b1524e6-239e-4570-bc73-70a188e12101"}}"#;
let expected = IngressMessages::InvalidateIceCandidates(ClientIceCandidates {
client_id: "2b1524e6-239e-4570-bc73-70a188e12101".parse().unwrap(),
candidates: vec![
"candidate:7854631899965427361 1 udp 1694498559 172.28.0.100 47717 typ srflx"
.to_owned(),
],
});
let actual = serde_json::from_str::<IngressMessages>(msg).unwrap();
assert_eq!(actual, expected);
}
#[test]
fn init_phoenix_message() {
let m = InitMessage::Init(InitGateway {

View File

@@ -383,7 +383,7 @@ impl<T> Eventloop<T> {
}
match self.pool.poll_event() {
Some(snownet::Event::SignalIceCandidate {
Some(snownet::Event::NewIceCandidate {
connection,
candidate,
}) => {
@@ -398,7 +398,7 @@ impl<T> Eventloop<T> {
Some(snownet::Event::ConnectionFailed(conn)) => {
return Poll::Ready(Ok(Event::ConnectionFailed { conn }))
}
None => {}
Some(snownet::Event::InvalidateIceCandidate { .. }) | None => {}
}
if let Poll::Ready(Some(wire::Candidate { conn, candidate })) =