feat(connlib): pick a single relay for each connection (#6060)

Currently, each connection always uses all relays. That is pretty
wasteful in terms of bandwidth usage and processing power because we
only ever need a single relay for a connection. When we re-deploy
relays, we actively invalidate them, meaning the connection gets cut
instantly without waiting for an ICE timeout and the next packet will
establish a new one.

This is now also asserted with a dedicated transition in `tunnel_test`.

To correctly simulate this in `tunnel_test`, we always cut the
connection to all relays. This frees us from modelling `connlib`'s
internal strategy for picking a relay which keeps the reference state
simple.

Resolves: #6014.
This commit is contained in:
Thomas Eizinger
2024-07-30 04:44:40 +01:00
committed by GitHub
parent 026feefc2c
commit 0230708182
17 changed files with 266 additions and 677 deletions

2
rust/Cargo.lock generated
View File

@@ -2005,7 +2005,6 @@ dependencies = [
"proptest",
"proptest-state-machine",
"rand 0.8.5",
"rand_core 0.6.4",
"rangemap",
"secrecy",
"serde",
@@ -5586,7 +5585,6 @@ dependencies = [
"boringtun",
"bytecodec",
"bytes",
"firezone-relay",
"hex",
"hex-display",
"ip-packet",

View File

@@ -20,7 +20,6 @@ thiserror = "1"
tracing = { workspace = true }
[dev-dependencies]
firezone-relay = { workspace = true }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[lints]

View File

@@ -11,7 +11,9 @@ use core::fmt;
use ip_packet::{
ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, MutableIpPacket, Packet as _,
};
use rand::random;
use rand::rngs::StdRng;
use rand::seq::IteratorRandom;
use rand::{random, SeedableRng};
use secrecy::{ExposeSecret, Secret};
use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet, HashSet};
@@ -96,6 +98,7 @@ pub struct Node<T, TId, RId> {
stats: NodeStats,
marker: PhantomData<T>,
rng: StdRng,
}
#[derive(thiserror::Error, Debug)]
@@ -121,9 +124,10 @@ where
TId: Eq + Hash + Copy + Ord + fmt::Display,
RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display,
{
pub fn new(private_key: StaticSecret) -> Self {
pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self {
let public_key = &(&private_key).into();
Self {
rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key.
private_key,
public_key: *public_key,
marker: Default::default(),
@@ -228,9 +232,12 @@ where
}
};
if let Some(agent) = self.connections.agent_mut(cid) {
agent.add_remote_candidate(candidate.clone());
}
let Some(agent) = self.connections.agent_mut(cid) else {
tracing::debug!("Unknown connection");
return;
};
agent.add_remote_candidate(candidate.clone());
match candidate.kind() {
CandidateKind::Host => {
@@ -238,21 +245,22 @@ where
// They are only useful to circumvent restrictive NATs in which case we are either talking to another relay candidate or a server-reflexive address.
return;
}
CandidateKind::Relayed => {
// Optimistically try to bind the channel only on the same relay as the remote peer.
if let Some(allocation) = self.same_relay_as_peer(&candidate) {
allocation.bind_channel(candidate.addr(), now);
return;
}
}
CandidateKind::ServerReflexive | CandidateKind::PeerReflexive => {}
CandidateKind::Relayed
| CandidateKind::ServerReflexive
| CandidateKind::PeerReflexive => {}
}
// In other cases, bind on all relays.
for allocation in self.allocations.values_mut() {
allocation.bind_channel(candidate.addr(), now);
}
let Some(rid) = self.connections.relay(cid) else {
tracing::debug!("No relay selected for connection");
return;
};
let Some(allocation) = self.allocations.get_mut(&rid) else {
tracing::debug!(%rid, "Unknown relay");
return;
};
allocation.bind_channel(candidate.addr(), now);
}
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
@@ -270,20 +278,6 @@ where
}
}
/// Looks up the [`Allocation`] that lives on the same relay as the remote's candidate.
///
/// An allocation matches if any of its current candidates has the same IP as `candidate`.
/// All candidates of each allocation are checked because the same relay might be reachable
/// over IPv4 and IPv6.
///
/// Returns `None` if no allocation has a candidate with a matching IP.
#[must_use]
fn same_relay_as_peer(&mut self, candidate: &Candidate) -> Option<&mut Allocation> {
    let remote_ip = candidate.addr().ip();

    self.allocations.values_mut().find(|allocation| {
        allocation
            .current_candidates()
            .any(|local| local.addr().ip() == remote_ip)
    })
}
/// Decapsulate an incoming packet.
///
/// # Returns
@@ -546,6 +540,7 @@ where
mut agent: IceAgent,
remote: PublicKey,
key: [u8; 32],
relay: Option<RId>,
intent_sent_at: Instant,
now: Instant,
) -> Connection<RId> {
@@ -576,6 +571,7 @@ where
possible_sockets: BTreeSet::default(),
buffered: RingBuffer::new(10),
},
relay,
last_outgoing: now,
last_incoming: now,
}
@@ -754,13 +750,14 @@ where
fn bindings_and_allocations_drain_events(&mut self) {
let allocation_events = self
.allocations
.values_mut()
.flat_map(|allocation| allocation.poll_event());
.iter_mut()
.flat_map(|(rid, allocation)| Some((*rid, allocation.poll_event()?)));
for event in allocation_events {
for (rid, event) in allocation_events {
match event {
CandidateEvent::New(candidate) => {
add_local_candidate_to_all(
rid,
candidate,
&mut self.connections,
&mut self.pending_events,
@@ -776,6 +773,11 @@ where
}
}
}
/// Picks one of the currently known relays for a new connection.
///
/// The relay is chosen at random from the IDs of `self.allocations` using the node's own RNG.
/// Returns `None` if there are no allocations to choose from.
fn sample_relay(&mut self) -> Option<RId> {
    let relay_ids = self.allocations.keys().copied();

    relay_ids.choose(&mut self.rng)
}
}
impl<TId, RId> Node<Client, TId, RId>
@@ -819,6 +821,7 @@ where
session_key,
created_at: now,
intent_sent_at,
relay: self.sample_relay(),
is_failed: false,
};
let duration_since_intent = initial_connection.duration_since_intent(now);
@@ -850,12 +853,15 @@ where
pass: answer.credentials.password,
});
self.seed_agent_with_local_candidates(cid, &mut agent);
let selected_relay = initial.relay;
self.seed_agent_with_local_candidates(cid, selected_relay, &mut agent);
let connection = self.init_connection(
agent,
remote,
*initial.session_key.expose_secret(),
selected_relay,
initial.intent_sent_at,
now,
);
@@ -911,12 +917,14 @@ where
},
};
self.seed_agent_with_local_candidates(cid, &mut agent);
let selected_relay = self.sample_relay();
self.seed_agent_with_local_candidates(cid, selected_relay, &mut agent);
let connection = self.init_connection(
agent,
remote,
*offer.session_key.expose_secret(),
selected_relay,
now, // Technically, this isn't fully correct because gateways don't send intents so we just use the current time.
now,
);
@@ -935,14 +943,25 @@ where
TId: Eq + Hash + Copy + fmt::Display,
RId: Copy + Eq + Hash + PartialEq + fmt::Debug + fmt::Display,
{
fn seed_agent_with_local_candidates(&mut self, connection: TId, agent: &mut IceAgent) {
fn seed_agent_with_local_candidates(
&mut self,
connection: TId,
selected_relay: Option<RId>,
agent: &mut IceAgent,
) {
for candidate in self.host_candidates.iter().cloned() {
add_local_candidate(connection, agent, candidate, &mut self.pending_events);
}
let Some(selected_relay) = selected_relay else {
tracing::debug!("Skipping seeding of relay candidates: No relay selected");
return;
};
for candidate in self
.allocations
.values()
.iter()
.filter_map(|(rid, allocation)| (*rid == selected_relay).then_some(allocation))
.flat_map(|allocation| allocation.current_candidates())
{
add_local_candidate(
@@ -956,7 +975,7 @@ where
}
struct Connections<TId, RId> {
initial: BTreeMap<TId, InitialConnection>,
initial: BTreeMap<TId, InitialConnection<RId>>,
established: BTreeMap<TId, Connection<RId>>,
}
@@ -1010,6 +1029,13 @@ where
maybe_initial_connection.or(maybe_established_connection)
}
/// Returns the relay that was selected for the connection `id`, if any.
///
/// The initial (not yet established) connection is consulted first; only if that yields
/// nothing do we look at the established connection. `None` means the connection is
/// unknown or no relay was selected for it.
///
/// This is a pure lookup, so it borrows `self` immutably; callers holding `&mut self`
/// can still call it.
fn relay(&self, id: TId) -> Option<RId> {
    self.initial
        .get(&id)
        .and_then(|initial| initial.relay)
        .or_else(|| {
            self.established
                .get(&id)
                .and_then(|established| established.relay)
        })
}
fn agents_mut(&mut self) -> impl Iterator<Item = (TId, &mut IceAgent)> {
let initial_agents = self.initial.iter_mut().map(|(id, c)| (*id, &mut c.agent));
let negotiated_agents = self
@@ -1024,7 +1050,7 @@ where
self.established.get_mut(id)
}
fn iter_initial_mut(&mut self) -> impl Iterator<Item = (TId, &mut InitialConnection)> {
fn iter_initial_mut(&mut self) -> impl Iterator<Item = (TId, &mut InitialConnection<RId>)> {
self.initial.iter_mut().map(|(id, conn)| (*id, conn))
}
@@ -1082,22 +1108,27 @@ enum EncodeError {
}
fn add_local_candidate_to_all<TId, RId>(
rid: RId,
candidate: Candidate,
connections: &mut Connections<TId, RId>,
pending_events: &mut VecDeque<Event<TId>>,
) where
TId: Copy + fmt::Display,
RId: Copy + PartialEq,
{
let initial_connections = connections
.initial
.iter_mut()
.map(|(id, c)| (*id, &mut c.agent));
.flat_map(|(id, c)| Some((*id, &mut c.agent, c.relay?)));
let established_connections = connections
.established
.iter_mut()
.map(|(id, c)| (*id, &mut c.agent));
.flat_map(|(id, c)| Some((*id, &mut c.agent, c.relay?)));
for (cid, agent) in initial_connections.chain(established_connections) {
for (cid, agent, _) in initial_connections
.chain(established_connections)
.filter(|(_, _, selected_relay)| *selected_relay == rid)
{
let _span = info_span!("connection", %cid).entered();
add_local_candidate(cid, agent, candidate.clone(), pending_events);
@@ -1244,17 +1275,22 @@ pub(crate) enum CandidateEvent {
Invalid(Candidate),
}
struct InitialConnection {
struct InitialConnection<RId> {
agent: IceAgent,
session_key: Secret<[u8; 32]>,
/// The fallback relay we sampled for this potential connection.
///
/// `None` if we don't have any relays available.
relay: Option<RId>,
created_at: Instant,
intent_sent_at: Instant,
is_failed: bool,
}
impl InitialConnection {
impl<RId> InitialConnection<RId> {
#[tracing::instrument(level = "debug", skip_all, fields(%cid))]
fn handle_timeout<TId>(&mut self, cid: TId, now: Instant)
where
@@ -1293,6 +1329,11 @@ struct Connection<RId> {
state: ConnectionState<RId>,
/// The relay we have selected for this connection.
///
/// `None` if we didn't have any relays available.
relay: Option<RId>,
stats: ConnectionStats,
intent_sent_at: Instant,
signalling_completed_at: Instant,
@@ -1308,6 +1349,7 @@ enum ConnectionState<RId> {
Connecting {
/// Socket addresses from which we might receive data (even before we are connected).
possible_sockets: BTreeSet<SocketAddr>,
/// Packets emitted by wireguard whilst we are still running ICE.
///
/// This can happen if the remote's WG session initiation arrives at our socket before we nominate it.
@@ -1514,6 +1556,7 @@ where
ConnectionState::Connecting {
possible_sockets,
buffered,
..
} => {
transmits.extend(buffered.into_iter().flat_map(|packet| {
make_owned_transmit(remote_socket, &packet, allocations, now)

View File

@@ -1,83 +1,11 @@
use boringtun::x25519::StaticSecret;
use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, IpStack, PeerSocket};
use ip_packet::*;
use rand::rngs::OsRng;
use snownet::{Answer, Client, ClientNode, Event, Node, RelaySocket, Server, ServerNode, Transmit};
use snownet::{Answer, ClientNode, Event, ServerNode};
use std::{
collections::{BTreeSet, HashSet, VecDeque},
iter,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
time::{Duration, Instant, SystemTime},
vec,
net::{IpAddr, Ipv4Addr, SocketAddr},
time::{Duration, Instant},
};
use str0m::{net::Protocol, Candidate};
use tracing::{debug_span, Span};
use tracing_subscriber::util::SubscriberInitExt;
/// Checks that two connected nodes can still exchange packets after their only relay
/// is swapped out for a different one.
///
/// Direct traffic between the two nodes' primary sockets is blocked by the firewall for the entire test.
#[test]
fn migrate_connection_to_new_relay() {
let _guard = setup_tracing();
let mut clock = Clock::new();
let (alice, bob) = alice_and_bob();
// Both nodes start out with a single relay: "Roger" (ID 1).
let mut relays = [(
1,
TestRelay::new(
SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3478),
debug_span!("Roger"),
),
)];
let mut alice = TestNode::new(debug_span!("Alice"), alice, "1.1.1.1:80").with_relays(
"alice",
BTreeSet::default(),
&mut relays,
clock.now,
);
let mut bob = TestNode::new(debug_span!("Bob"), bob, "2.2.2.2:80").with_relays(
"bob",
BTreeSet::default(),
&mut relays,
clock.now,
);
// Drop packets sent between the primary sockets of Alice and Bob, in both directions.
let firewall = Firewall::default()
.with_block_rule(&alice, &bob)
.with_block_rule(&bob, &alice);
handshake(&mut alice, &mut bob, &clock);
// Drive the simulation until both sides report a connection to each other.
// `Clock::tick` panics once its maximum time is exceeded, so this loop cannot spin forever.
loop {
if alice.is_connected_to(&bob) && bob.is_connected_to(&alice) {
break;
}
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
}
// Swap out the relays. "Roger" is being removed (ID 1) and "Robert" is being added (ID 2).
let mut relays = [(
2,
TestRelay::new(
SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 3478),
debug_span!("Robert"),
),
)];
alice = alice.with_relays("alice", BTreeSet::from([1]), &mut relays, clock.now);
bob = bob.with_relays("bob", BTreeSet::from([1]), &mut relays, clock.now);
// Make some progress. (the fact that we only need 22 clock ticks means we are not relying on timeouts here (22 * 100ms = 2.2s))
for _ in 0..22 {
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
}
// Alice -> Bob: exactly one packet from 9.9.9.9 must have arrived at Bob.
alice.ping(ip("9.9.9.9"), ip("8.8.8.8"), &bob, clock.now);
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
assert_eq!(bob.packets_from(ip("9.9.9.9")).count(), 1);
// Bob -> Alice: exactly one packet from 8.8.8.8 must have arrived at Alice.
bob.ping(ip("8.8.8.8"), ip("9.9.9.9"), &alice, clock.now);
progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock);
assert_eq!(alice.packets_from(ip("8.8.8.8")).count(), 1);
}
#[test]
fn connection_times_out_after_20_seconds() {
@@ -147,10 +75,16 @@ fn answer_after_stale_connection_does_not_panic() {
fn only_generate_candidate_event_after_answer() {
let local_candidate = SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 10000);
let mut alice = ClientNode::<u64, u64>::new(StaticSecret::random_from_rng(rand::thread_rng()));
let mut alice = ClientNode::<u64, u64>::new(
StaticSecret::random_from_rng(rand::thread_rng()),
rand::random(),
);
alice.add_local_host_candidate(local_candidate).unwrap();
let mut bob = ServerNode::<u64, u64>::new(StaticSecret::random_from_rng(rand::thread_rng()));
let mut bob = ServerNode::<u64, u64>::new(
StaticSecret::random_from_rng(rand::thread_rng()),
rand::random(),
);
let offer = alice.new_connection(1, Instant::now(), Instant::now());
@@ -173,17 +107,15 @@ fn only_generate_candidate_event_after_answer() {
}));
}
/// Builds a `tracing` subscriber for tests and installs it as the default.
///
/// The subscriber writes via the test writer and uses the `debug` env-filter directive.
/// The returned guard must be kept alive by the caller.
fn setup_tracing() -> tracing::subscriber::DefaultGuard {
    let subscriber = tracing_subscriber::fmt()
        .with_test_writer()
        .with_env_filter("debug")
        .finish();

    subscriber.set_default()
}
fn alice_and_bob() -> (ClientNode<u64, u64>, ServerNode<u64, u64>) {
let alice = ClientNode::new(StaticSecret::random_from_rng(rand::thread_rng()));
let bob = ServerNode::new(StaticSecret::random_from_rng(rand::thread_rng()));
let alice = ClientNode::new(
StaticSecret::random_from_rng(rand::thread_rng()),
rand::random(),
);
let bob = ServerNode::new(
StaticSecret::random_from_rng(rand::thread_rng()),
rand::random(),
);
(alice, bob)
}
@@ -207,513 +139,3 @@ fn host(socket: &str) -> String {
/// Parses `socket` (e.g. `"1.1.1.1:80"`) into a [`SocketAddr`].
///
/// # Panics
///
/// Panics if `socket` is not a valid socket address. The panic message includes the
/// offending input so a typo in a test literal is easy to locate.
fn s(socket: &str) -> SocketAddr {
    socket
        .parse::<SocketAddr>()
        .unwrap_or_else(|e| panic!("invalid socket address {:?}: {}", socket, e))
}
/// Parses `ip` (e.g. `"9.9.9.9"`) into an [`IpAddr`].
///
/// # Panics
///
/// Panics if `ip` is not a valid IPv4 or IPv6 address. The panic message includes the
/// offending input so a typo in a test literal is easy to locate.
fn ip(ip: &str) -> IpAddr {
    ip.parse::<IpAddr>()
        .unwrap_or_else(|e| panic!("invalid IP address {:?}: {}", ip, e))
}
// Heavily inspired by https://github.com/algesten/str0m/blob/7ed5143381cf095f7074689cc254b8c9e50d25c5/src/ice/mod.rs#L547-L647.
/// Test harness around a [`Node`] that records what the node sends, receives and emits.
struct TestNode<R> {
/// The node under test.
node: Node<R, u64, u64>,
/// Transmits queued by the harness itself (e.g. by `ping`); sent on the next `drain_transmits`.
transmits: VecDeque<Transmit<'static>>,
/// Span in whose scope calls into `node` are made.
span: Span,
/// IP packets that `node` successfully decapsulated.
received_packets: Vec<IpPacket<'static>>,
/// The primary interface we use to send packets (e.g. to relays).
primary: SocketAddr,
/// All local interfaces.
local: Vec<SocketAddr>,
/// Every event polled from `node`, together with the time at which it was polled.
events: Vec<(Event<u64>, Instant)>,
/// Scratch buffer handed to `decapsulate`.
buffer: Box<[u8; 10_000]>,
}
/// In-memory relay used by the tests: wraps a `firezone_relay::Server` and tracks its allocations.
struct TestRelay {
/// The relay server implementation being driven.
inner: firezone_relay::Server<OsRng>,
/// The socket(s) on which this relay accepts client traffic.
listen_addr: RelaySocket,
/// Span in whose scope calls into `inner` are made.
span: Span,
/// Allocations currently open on this relay; kept in sync by `drain_messages`.
allocations: HashSet<(AddressFamily, AllocationPort)>,
/// Scratch buffer used to assemble channel-data messages for clients.
buffer: Vec<u8>,
}
/// A set of block rules applied to traffic sent directly between nodes.
#[derive(Default)]
struct Firewall {
/// `(source, destination)` socket pairs whose packets are dropped.
blocked: Vec<(SocketAddr, SocketAddr)>,
}
/// A manually advanced clock that drives the test simulation.
struct Clock {
/// The instant at which the clock was created.
start: Instant,
/// The current simulated time.
now: Instant,
/// How far `now` advances on every `tick`.
tick_rate: Duration,
/// Upper bound for `now`; `tick` panics once it is reached.
max_time: Instant,
}
impl Clock {
    /// Creates a clock starting at the current instant, ticking in 100ms steps, with a
    /// time budget of one hour.
    fn new() -> Self {
        let start = Instant::now();
        let budget = Duration::from_secs(60) * 60;

        Self {
            start,
            now: start,
            tick_rate: Duration::from_millis(100),
            max_time: start + budget,
        }
    }

    /// Advances the clock by one tick.
    ///
    /// Logs the elapsed time whenever it is a whole multiple of 60 seconds.
    ///
    /// # Panics
    ///
    /// Panics with "Time exceeded" once the clock reaches its maximum time.
    fn tick(&mut self) {
        self.now += self.tick_rate;

        let elapsed = self.elapsed(self.start);
        let is_full_minute = elapsed.as_millis() % 60_000 == 0;

        if is_full_minute {
            tracing::info!("Time since start: {elapsed:?}")
        }

        assert!(self.now < self.max_time, "Time exceeded");
    }

    /// Returns how much simulated time has passed between `start` and the clock's current time.
    fn elapsed(&self, start: Instant) -> Duration {
        let current = self.now;

        current.duration_since(start)
    }
}
impl Firewall {
    /// Adds a rule that blocks packets sent from `from`'s primary socket to `to`'s primary socket.
    ///
    /// The rule is directional; call it a second time with the arguments swapped to block
    /// the reverse direction.
    fn with_block_rule<R1, R2>(mut self, from: &TestNode<R1>, to: &TestNode<R2>) -> Self {
        let rule = (from.primary, to.primary);
        self.blocked.push(rule);

        self
    }
}
impl TestRelay {
/// Creates a relay listening on `local`, with no allocations yet.
///
/// The inner server is constructed with port `3478` and the range `49152..=65535`.
fn new(local: impl Into<RelaySocket>, span: Span) -> Self {
let local = local.into();
let inner = firezone_relay::Server::new(to_ip_stack(local), OsRng, 3478, 49152..=65535);
Self {
inner,
listen_addr: local,
span,
allocations: HashSet::default(),
buffer: vec![0u8; (1 << 16) - 1],
}
}
/// Whether a packet addressed to `dst` should be handled by this relay.
///
/// That is the case if `dst` equals one of the relay's listen sockets or if its
/// address family and port match one of the tracked allocations.
fn wants(&self, dst: SocketAddr) -> bool {
let is_v4_ctrl = self
.listen_addr
.as_v4()
.is_some_and(|v4| SocketAddr::V4(*v4) == dst);
let is_v6_ctrl = self
.listen_addr
.as_v6()
.is_some_and(|v6| SocketAddr::V6(*v6) == dst);
// Allocations are matched on address family and port only; the IP of `dst` is not compared.
let is_allocation = self.allocations.contains(&match dst {
SocketAddr::V4(_) => (AddressFamily::V4, AllocationPort::new(dst.port())),
SocketAddr::V6(_) => (AddressFamily::V6, AllocationPort::new(dst.port())),
});
is_v4_ctrl || is_v6_ctrl || is_allocation
}
/// Returns the relay's listen socket of the same IP family as `other`.
///
/// `None` if the relay does not listen on that family.
fn matching_listen_socket(&self, other: SocketAddr) -> Option<SocketAddr> {
match other {
SocketAddr::V4(_) => Some(SocketAddr::V4(*self.listen_addr.as_v4()?)),
SocketAddr::V6(_) => Some(SocketAddr::V6(*self.listen_addr.as_v6()?)),
}
}
/// The relay's IPv4 listen address, if it has one.
fn ip4(&self) -> Option<IpAddr> {
self.listen_addr.as_v4().map(|s| IpAddr::V4(*s.ip()))
}
/// The relay's IPv6 listen address, if it has one.
fn ip6(&self) -> Option<IpAddr> {
self.listen_addr.as_v6().map(|s| IpAddr::V6(*s.ip()))
}
/// Entry point for a packet sent by `sender` to `dst` on this relay.
///
/// Packets addressed to the listen socket are treated as client input, everything else
/// as peer traffic arriving on the allocation identified by `dst`'s port.
/// Anything the relay forwards as a result is delivered to `other`.
fn handle_packet<R>(
&mut self,
payload: &[u8],
sender: SocketAddr,
dst: SocketAddr,
other: &mut TestNode<R>,
now: Instant,
) {
if self.listen_addr.matches(dst) {
self.handle_client_input(payload, ClientSocket::new(sender), other, now);
return;
}
self.handle_peer_traffic(
payload,
PeerSocket::new(sender),
AllocationPort::new(dst.port()),
other,
now,
)
}
/// Feeds `payload` from `client` into the relay server and, if the server asks for it
/// to be relayed, forwards it to the peer.
///
/// If the peer address is itself handled by this relay, the data is looped back into
/// `handle_peer_traffic`; otherwise it is delivered to `receiver`.
fn handle_client_input<R>(
&mut self,
payload: &[u8],
client: ClientSocket,
receiver: &mut TestNode<R>,
now: Instant,
) {
if let Some((port, peer)) = self
.span
.in_scope(|| self.inner.handle_client_input(payload, client, now))
{
// Skip the first 4 bytes — presumably the channel-data header, which is 4 bytes in `handle_peer_traffic` (TODO confirm).
let payload = &payload[4..];
// The `dst` of the relayed packet is what TURN calls a "peer".
let dst = peer.into_socket();
// The `src_ip` is the relay's IP
let src_ip = match dst {
SocketAddr::V4(_) => {
assert!(
self.allocations.contains(&(AddressFamily::V4, port)),
"IPv4 allocation to be present if we want to send to an IPv4 socket"
);
self.ip4().expect("listen on IPv4 if we have an allocation")
}
SocketAddr::V6(_) => {
assert!(
self.allocations.contains(&(AddressFamily::V6, port)),
"IPv6 allocation to be present if we want to send to an IPv6 socket"
);
self.ip6().expect("listen on IPv6 if we have an allocation")
}
};
// The `src` of the relayed packet is the relay itself _from_ the allocated port.
let src = SocketAddr::new(src_ip, port.value());
// Check if we need to relay to ourselves (from one allocation to another)
if self.wants(dst) {
// When relaying to ourselves, we become our own peer.
let peer_socket = PeerSocket::new(src);
// The allocation that the data is arriving on is the `dst`'s port.
let allocation_port = AllocationPort::new(dst.port());
self.handle_peer_traffic(payload, peer_socket, allocation_port, receiver, now);
return;
}
receiver.receive(dst, src, payload, now);
}
}
/// Feeds `payload` from `peer`, arriving on allocation `port`, into the relay server and,
/// if the server maps it to a client and channel, delivers it to `receiver` with a
/// channel-data header prepended.
fn handle_peer_traffic<R>(
&mut self,
payload: &[u8],
peer: PeerSocket,
port: AllocationPort,
receiver: &mut TestNode<R>,
now: Instant,
) {
if let Some((client, channel)) = self
.span
.in_scope(|| self.inner.handle_peer_traffic(payload, peer, port))
{
// NOTE(review): `as u16` silently truncates payloads longer than `u16::MAX` — confirm payloads are always smaller.
let full_length = firezone_relay::ChannelData::encode_header_to_slice(
channel,
payload.len() as u16,
&mut self.buffer[..4],
);
// The header occupies the first 4 bytes of the buffer; the payload follows directly after.
self.buffer[4..full_length].copy_from_slice(payload);
let receiving_socket = client.into_socket();
// The message is sent from the listen socket of the same IP family as the client's socket.
let sending_socket = self.matching_listen_socket(receiving_socket).unwrap();
receiver.receive(
receiving_socket,
sending_socket,
&self.buffer[..full_length],
now,
);
}
}
/// Processes all pending commands of the relay server.
///
/// Messages are delivered to whichever of `a1` / `a2` owns the recipient socket;
/// allocation commands update the relay's set of tracked allocations.
///
/// # Panics
///
/// Panics if a message is addressed to a socket that belongs to neither node, or if the
/// relay has no listen socket of the recipient's IP family.
fn drain_messages<R1, R2>(
&mut self,
a1: &mut TestNode<R1>,
a2: &mut TestNode<R2>,
now: Instant,
) {
while let Some(command) = self.inner.next_command() {
match command {
firezone_relay::Command::SendMessage { payload, recipient } => {
let recipient = recipient.into_socket();
let sending_socket = self.matching_listen_socket(recipient).unwrap();
if a1.local.contains(&recipient) {
a1.receive(recipient, sending_socket, &payload, now);
continue;
}
if a2.local.contains(&recipient) {
a2.receive(recipient, sending_socket, &payload, now);
continue;
}
panic!("Relay generated traffic for unknown client")
}
firezone_relay::Command::CreateAllocation { port, family } => {
self.allocations.insert((family, port));
}
firezone_relay::Command::FreeAllocation { port, family } => {
self.allocations.remove(&(family, port));
}
}
}
}
/// Creates a `(username, password)` pair for `username` on this relay.
///
/// The credentials expire 60 seconds from the current system time. The returned
/// username has the form `{expiry as unix seconds}:{username}`.
fn make_credentials(&self, username: &str) -> (String, String) {
let expiry = SystemTime::now() + Duration::from_secs(60);
let secs = expiry
.duration_since(SystemTime::UNIX_EPOCH)
.expect("expiry must be later than UNIX_EPOCH")
.as_secs();
let password =
firezone_relay::auth::generate_password(self.inner.auth_secret(), expiry, username);
(format!("{secs}:{username}"), password)
}
}
/// Converts the socket(s) a relay listens on into the matching [`IpStack`].
///
/// Only the IP part of each socket is used; ports are discarded.
fn to_ip_stack(socket: RelaySocket) -> IpStack {
    match socket {
        RelaySocket::V4(addr) => IpStack::Ip4(*addr.ip()),
        RelaySocket::V6(addr) => IpStack::Ip6(*addr.ip()),
        RelaySocket::Dual { v4, v6 } => {
            let ip4 = *v4.ip();
            let ip6 = *v6.ip();

            IpStack::Dual { ip4, ip6 }
        }
    }
}
impl<R> TestNode<R> {
/// Wraps `node` in a test harness whose only local interface is `primary`.
///
/// # Panics
///
/// Panics if `primary` is not a valid socket address.
pub fn new(span: Span, node: Node<R, u64, u64>, primary: &str) -> Self {
let primary = primary.parse().unwrap();
TestNode {
node,
span,
received_packets: vec![],
buffer: Box::new([0u8; 10_000]),
primary,
local: vec![primary],
events: Default::default(),
transmits: Default::default(),
}
}
/// Updates the node's relays: removes the IDs in `to_remove` and announces all `relays`.
///
/// Credentials for each relay are generated for `username`; the realm is always `firezone`.
fn with_relays(
mut self,
username: &str,
to_remove: BTreeSet<u64>,
relays: &mut [(u64, TestRelay)],
now: Instant,
) -> Self {
let turn_servers = relays
.iter()
.map(|(idx, relay)| {
let (username, password) = relay.make_credentials(username);
(
*idx,
relay.listen_addr,
username,
password,
"firezone".to_owned(),
)
})
.collect::<BTreeSet<_>>();
self.span
.in_scope(|| self.node.update_relays(to_remove, &turn_servers, now));
self
}
/// Whether this node reports a connection ID for `other`'s public key.
fn is_connected_to<RO>(&self, other: &TestNode<RO>) -> bool {
self.node.connection_id(other.node.public_key()).is_some()
}
/// Queues an encapsulated ICMP echo request from `src` to `dst` for the connection to `other`.
///
/// The packet is only sent on the next call to `drain_transmits`.
///
/// # Panics
///
/// Panics if there is no connection to `other`, or if encapsulation fails or yields no transmit.
fn ping<RO>(&mut self, src: IpAddr, dst: IpAddr, other: &TestNode<RO>, now: Instant) {
let id = self
.node
.connection_id(other.node.public_key())
.expect("cannot ping not-connected node");
let transmit = self
.span
.in_scope(|| {
self.node.encapsulate(
id,
ip_packet::make::icmp_request_packet(src, dst, 1, 0).to_immutable(),
now,
)
})
.unwrap()
.unwrap()
.into_owned();
self.transmits.push_back(transmit);
}
/// Iterates over all received packets whose source IP is `src`.
fn packets_from(&self, src: IpAddr) -> impl Iterator<Item = &IpPacket<'static>> {
self.received_packets
.iter()
.filter(move |p| p.source() == src)
}
/// Hands a packet that arrived on `local` from `from` to the node and records any
/// IP packet the node decapsulates from it.
///
/// # Panics
///
/// Panics if decapsulation returns an error. In debug builds, also panics if `local`
/// is not one of this node's local interfaces.
fn receive(&mut self, local: SocketAddr, from: SocketAddr, packet: &[u8], now: Instant) {
debug_assert!(self.local.contains(&local));
if let Some((_, packet)) = self
.span
.in_scope(|| {
self.node
.decapsulate(local, from, packet, now, self.buffer.as_mut())
})
.unwrap()
{
self.received_packets.push(packet.to_immutable().to_owned())
}
}
/// Polls all pending events from this node, records them and forwards candidate
/// updates to `other`.
///
/// This acts as the signalling channel between the two nodes: new candidates are added
/// to `other` as remote candidates and invalidated ones are removed from it.
fn drain_events<RO>(&mut self, other: &mut TestNode<RO>, now: Instant) {
while let Some(v) = self.span.in_scope(|| self.node.poll_event()) {
self.events.push((v.clone(), now));
match v {
Event::NewIceCandidate {
connection,
candidate,
} => other
.span
.in_scope(|| other.node.add_remote_candidate(connection, candidate, now)),
Event::InvalidateIceCandidate {
connection,
candidate,
} => other
.span
.in_scope(|| other.node.remove_remote_candidate(connection, candidate)),
// Connection state changes are only recorded, not acted upon.
Event::ConnectionEstablished(_)
| Event::ConnectionFailed(_)
| Event::ConnectionClosed(_) => {}
};
}
}
/// Sends all pending transmits: those polled from the node first, followed by the ones
/// queued on the harness.
///
/// A transmit is handed to the first relay that wants its destination. Otherwise it is
/// delivered directly to `other`, unless it has no source, is blocked by `firewall` or
/// is addressed to a socket that `other` does not own; in those cases it is dropped.
fn drain_transmits<RO>(
&mut self,
other: &mut TestNode<RO>,
relays: &mut [(u64, TestRelay)],
firewall: &Firewall,
now: Instant,
) {
for trans in iter::from_fn(|| self.node.poll_transmit()).chain(self.transmits.drain(..)) {
let payload = &trans.payload;
let dst = trans.dst;
// Relay traffic always originates from our primary interface and bypasses the firewall.
if let Some((_, relay)) = relays.iter_mut().find(|(_, r)| r.wants(trans.dst)) {
relay.handle_packet(payload, self.primary, dst, other, now);
continue;
}
// A transmit without a source that no relay wanted is logged and dropped.
let Some(src) = trans.src else {
tracing::debug!(target: "router", %dst, "Unknown relay");
continue;
};
// Wasn't traffic for the relay, let's check our firewall.
if firewall.blocked.contains(&(src, dst)) {
tracing::debug!(target: "firewall", %src, %dst, "Dropping packet");
continue;
}
if !other.local.contains(&dst) {
tracing::debug!(target: "router", %src, %dst, "Unknown destination");
continue;
}
// Firewall allowed traffic, let's dispatch it.
other.receive(dst, src, payload, now);
}
}
}
/// Performs the offer / answer exchange for connection `1` between `client` and `server`.
///
/// The client creates the offer, the server accepts it and produces an answer, and the
/// client accepts that answer. Every step happens at the clock's current time and inside
/// the span of the node that executes it.
fn handshake(client: &mut TestNode<Client>, server: &mut TestNode<Server>, clock: &Clock) {
    let now = clock.now;

    let offer = client
        .span
        .in_scope(|| client.node.new_connection(1, now, now));

    let answer = server.span.in_scope(|| {
        let client_key = client.node.public_key();

        server.node.accept_connection(1, offer, client_key, now)
    });

    client.span.in_scope(|| {
        let server_key = server.node.public_key();

        client.node.accept_answer(1, server_key, answer, now)
    });
}
/// Advances the simulation by one clock tick.
///
/// In order: ticks the clock, exchanges pending events between the two nodes, sends their
/// pending transmits, lets every relay process its commands, fires any timeouts that are
/// due on the nodes and relays, and finally exchanges events once more.
fn progress<R1, R2>(
    a1: &mut TestNode<R1>,
    a2: &mut TestNode<R2>,
    relays: &mut [(u64, TestRelay)],
    firewall: &Firewall,
    clock: &mut Clock,
) {
    clock.tick();
    let now = clock.now;

    a1.drain_events(a2, now);
    a2.drain_events(a1, now);

    a1.drain_transmits(a2, relays, firewall, now);
    a2.drain_transmits(a1, relays, firewall, now);

    for (_, relay) in relays.iter_mut() {
        relay.drain_messages(a1, a2, now);
    }

    // Only fire a timeout once the simulated time has reached it.
    if a1.node.poll_timeout().is_some_and(|due| now >= due) {
        a1.span.in_scope(|| a1.node.handle_timeout(now));
    }

    if a2.node.poll_timeout().is_some_and(|due| now >= due) {
        a2.span.in_scope(|| a2.node.handle_timeout(now));
    }

    for (_, relay) in relays.iter_mut() {
        if relay.inner.poll_timeout().is_some_and(|due| now >= due) {
            relay.span.in_scope(|| relay.inner.handle_timeout(now));
        }
    }

    a1.drain_events(a2, now);
    a2.drain_events(a1, now);
}

View File

@@ -22,7 +22,7 @@ ip_network = { version = "0.4", default-features = false }
ip_network_table = { version = "0.2", default-features = false }
itertools = { version = "0.13", default-features = false, features = ["use_std"] }
proptest = { version = "1", optional = true }
rand_core = { version = "0.6", default-features = false, features = ["getrandom"] }
rand = "0.8.5"
rangemap = "1.5.1"
secrecy = { workspace = true }
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }

View File

@@ -291,6 +291,7 @@ impl ClientState {
pub(crate) fn new(
private_key: impl Into<StaticSecret>,
known_hosts: HashMap<String, Vec<IpAddr>>,
seed: [u8; 32],
) -> Self {
Self {
awaiting_connection_details: Default::default(),
@@ -303,7 +304,7 @@ impl ClientState {
interface_config: Default::default(),
buffered_packets: Default::default(),
buffered_dns_queries: Default::default(),
node: ClientNode::new(private_key.into()),
node: ClientNode::new(private_key.into(), seed),
system_resolvers: Default::default(),
sites_status: Default::default(),
gateways_site: Default::default(),
@@ -1323,7 +1324,7 @@ impl IpProvider {
#[cfg(test)]
mod tests {
use super::*;
use rand_core::OsRng;
use rand::rngs::OsRng;
#[test]
fn ignores_ip4_igmp_multicast() {
@@ -1496,7 +1497,11 @@ mod tests {
impl ClientState {
pub fn for_test() -> ClientState {
ClientState::new(StaticSecret::random_from_rng(OsRng), HashMap::new())
ClientState::new(
StaticSecret::random_from_rng(OsRng),
HashMap::new(),
rand::random(),
)
}
}

View File

@@ -131,10 +131,10 @@ pub struct GatewayState {
}
impl GatewayState {
pub(crate) fn new(private_key: impl Into<StaticSecret>) -> Self {
pub(crate) fn new(private_key: impl Into<StaticSecret>, seed: [u8; 32]) -> Self {
Self {
peers: Default::default(),
node: ServerNode::new(private_key.into()),
node: ServerNode::new(private_key.into(), seed),
next_expiry_resources_check: Default::default(),
buffered_events: VecDeque::default(),
}

View File

@@ -85,7 +85,7 @@ impl ClientTunnel {
) -> std::io::Result<Self> {
Ok(Self {
io: Io::new(tcp_socket_factory, udp_socket_factory)?,
role_state: ClientState::new(private_key, known_hosts),
role_state: ClientState::new(private_key, known_hosts, rand::random()),
write_buf: Box::new([0u8; DEFAULT_MTU + 16 + 20]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
@@ -183,7 +183,7 @@ impl GatewayTunnel {
pub fn new(private_key: StaticSecret) -> std::io::Result<Self> {
Ok(Self {
io: Io::new(Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp))?,
role_state: GatewayState::new(private_key),
role_state: GatewayState::new(private_key, rand::random()),
write_buf: Box::new([0u8; DEFAULT_MTU + 20 + 16]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),

View File

@@ -1,5 +1,5 @@
use super::{
composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*, sim_relay::*,
composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*,
strategies::*, stub_portal::StubPortal, transition::*,
};
use crate::dns::is_subdomain;
@@ -8,11 +8,9 @@ use connlib_shared::{
client::{self, ResourceDescription},
GatewayId, RelayId,
},
proptest::*,
DomainName, StaticSecret,
};
use hickory_proto::rr::RecordType;
use prop::collection;
use proptest::{prelude::*, sample};
use proptest_state_machine::ReferenceStateMachine;
use std::{
@@ -62,7 +60,7 @@ impl ReferenceStateMachine for ReferenceState {
(
ref_client_host(Just(client_tunnel_ip4), Just(client_tunnel_ip6)),
gateways_and_portal(),
collection::btree_map(relay_id(), relay_prototype(), 1..=2),
relays(),
global_dns_records(), // Start out with a set of global DNS records so we have something to resolve outside of DNS resources.
any::<bool>(),
)
@@ -167,6 +165,10 @@ impl ReferenceStateMachine for ReferenceState {
sample::select(resource_ids).prop_map(Transition::DeactivateResource)
})
.with(1, roam_client())
.with(
1,
migrate_relays(Just(state.relays.keys().copied().collect())), // Always take down all relays because we can't know which one was sampled for the connection.
)
.with(1, Just(Transition::ReconnectPortal))
.with(1, Just(Transition::Idle))
.with_if_not_empty(
@@ -376,19 +378,33 @@ impl ReferenceStateMachine for ReferenceState {
.add_host(state.client.inner().id, &state.client));
// When roaming, we are not connected to any resource and wait for the next packet to re-establish a connection.
state.client.exec_mut(|client| {
client.connected_cidr_resources.clear();
client.connected_dns_resources.clear();
});
state.client.exec_mut(|client| client.reset_connections());
}
Transition::ReconnectPortal => {
// Reconnecting to the portal should have no noticeable impact on the data plane.
}
Transition::RelaysPresence {
disconnected,
online,
} => {
for rid in disconnected {
let disconnected_relay =
state.relays.remove(rid).expect("old host to be present");
state.network.remove_host(&disconnected_relay);
}
for (rid, online_relay) in online {
state.relays.insert(*rid, online_relay.clone());
debug_assert!(state.network.add_host(*rid, online_relay));
}
// In case we were using the relays, all connections will be cut and require us to make a new one.
if state.drop_direct_client_traffic {
state.client.exec_mut(|client| client.reset_connections());
}
}
Transition::Idle => {
state.client.exec_mut(|client| {
client.connected_cidr_resources.clear();
client.connected_dns_resources.clear();
});
state.client.exec_mut(|client| client.reset_connections());
}
};
@@ -527,6 +543,26 @@ impl ReferenceStateMachine for ReferenceState {
Transition::DeactivateResource(r) => {
state.client.inner().all_resource_ids().contains(r)
}
Transition::RelaysPresence {
disconnected,
online,
} => {
let all_old_are_present = disconnected
.iter()
.all(|rid| state.relays.contains_key(rid));
let no_new_are_present = online.keys().all(|rid| !state.relays.contains_key(rid));
let mut additional_routes = RoutingTable::default();
for (rid, relay) in online {
if !additional_routes.add_host(*rid, relay) {
return false;
}
}
let route_overlap = state.network.overlaps_with(&additional_routes);
all_old_are_present && no_new_are_present && !route_overlap
}
Transition::Idle => true,
}
}
@@ -560,7 +596,7 @@ pub(crate) fn private_key() -> impl Strategy<Value = PrivateKey> {
}
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct PrivateKey([u8; 32]);
pub(crate) struct PrivateKey(pub [u8; 32]);
impl From<PrivateKey> for StaticSecret {
fn from(key: PrivateKey) -> Self {

View File

@@ -255,7 +255,7 @@ impl RefClient {
///
/// This simulates receiving the `init` message from the portal.
pub(crate) fn init(self) -> SimClient {
let mut client_state = ClientState::new(self.key, self.known_hosts);
let mut client_state = ClientState::new(self.key, self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed.
let _ = client_state.update_interface_config(Interface {
ipv4: self.tunnel_ip4,
ipv6: self.tunnel_ip6,
@@ -266,6 +266,11 @@ impl RefClient {
SimClient::new(self.id, client_state)
}
/// Forgets every resource connection tracked by this reference client.
///
/// Both the CIDR and the DNS resource connection sets are emptied.
pub(crate) fn reset_connections(&mut self) {
    // The two collections are independent of each other, so the order of clearing is irrelevant.
    self.connected_dns_resources.clear();
    self.connected_cidr_resources.clear();
}
pub(crate) fn is_tunnel_ip(&self, ip: IpAddr) -> bool {
match ip {
IpAddr::V4(ip4) => self.tunnel_ip4 == ip4,

View File

@@ -99,7 +99,7 @@ impl RefGateway {
///
/// This simulates receiving the `init` message from the portal.
pub(crate) fn init(self) -> SimGateway {
SimGateway::new(GatewayState::new(self.key))
SimGateway::new(GatewayState::new(self.key, self.key.0)) // Cheating a bit here by reusing the key as seed.
}
}

View File

@@ -236,6 +236,13 @@ impl RoutingTable {
/// Looks up the host that is registered in the routing table under `ip`.
///
/// Returns a copy of the stored [`HostId`], or `None` if the lookup via `exact_match` finds nothing.
pub(crate) fn host_by_ip(&self, ip: IpAddr) -> Option<HostId> {
    let host = self.routes.exact_match(ip)?;

    Some(*host)
}
/// Checks whether any route of `other` is already present in this routing table.
///
/// Returns `true` as soon as one route of `other` yields an `exact_match` hit in `self`,
/// and `false` if none does (including when `other` has no routes at all).
// NOTE(review): only `exact_match` is consulted; routes that merely contain or are contained
// in one another are presumably not reported as overlapping - confirm against the table's semantics.
pub(crate) fn overlaps_with(&self, other: &Self) -> bool {
    for (route, _) in other.routes.iter() {
        if self.routes.exact_match(route).is_some() {
            return true;
        }
    }

    false
}
}
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Hash)]

View File

@@ -196,7 +196,7 @@ impl SimRelay {
}
}
pub(crate) fn relay_prototype() -> impl Strategy<Value = Host<u64>> {
pub(crate) fn ref_relay_host() -> impl Strategy<Value = Host<u64>> {
host(
dual_ip_stack(), // For this test, our relays always run in dual-stack mode to ensure connectivity!
Just(3478),

View File

@@ -1,16 +1,18 @@
use super::{
sim_gateway::{ref_gateway_host, RefGateway},
sim_net::Host,
sim_relay::ref_relay_host,
stub_portal::StubPortal,
};
use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES};
use connlib_shared::{
messages::{
client::{ResourceDescriptionCidr, ResourceDescriptionDns, Site, SiteId},
DnsServer, GatewayId,
DnsServer, GatewayId, RelayId,
},
proptest::{
any_ip_network, cidr_resource, dns_resource, domain_label, domain_name, gateway_id, site,
any_ip_network, cidr_resource, dns_resource, domain_label, domain_name, gateway_id,
relay_id, site,
},
DomainName,
};
@@ -195,6 +197,10 @@ pub(crate) fn gateways_and_portal() -> impl Strategy<
)
}
/// Strategy for a set of relays, keyed by their [`RelayId`].
///
/// Ids come from `relay_id()`, hosts from `ref_relay_host()`; the map is sized by `1..=2`.
pub(crate) fn relays() -> impl Strategy<Value = BTreeMap<RelayId, Host<u64>>> {
    let ids = relay_id();
    let hosts = ref_relay_host();

    collection::btree_map(ids, hosts, 1..=2)
}
/// Strategy that selects one [`Site`] out of the given set.
fn any_site(sites: HashSet<Site>) -> impl Strategy<Value = Site> {
    // `sample::select` is fed a `Vec`, so move the set's elements into one first.
    let candidates: Vec<Site> = sites.into_iter().collect();

    sample::select(candidates)
}

View File

@@ -281,6 +281,53 @@ impl StateMachineTest for TunnelTest {
c.sut.set_resources(all_resources);
});
}
Transition::RelaysPresence {
disconnected,
online,
} => {
for rid in &disconnected {
let disconnected_relay =
state.relays.remove(rid).expect("old relay to be present");
state.network.remove_host(&disconnected_relay);
}
let online = online
.into_iter()
.map(|(rid, relay)| (rid, relay.map(SimRelay::new, debug_span!("relay", %rid))))
.collect::<BTreeMap<_, _>>();
for (rid, relay) in &online {
debug_assert!(state.network.add_host(*rid, relay));
}
state.client.exec_mut({
let disconnected = disconnected.clone();
let online = online.iter();
move |c| {
c.sut.update_relays(
disconnected,
BTreeSet::from_iter(map_explode(online, "client")),
now,
);
}
});
for (id, gateway) in &mut state.gateways {
gateway.exec_mut({
let disconnected = disconnected.clone();
let online = online.iter();
move |g| {
g.sut.update_relays(
disconnected,
BTreeSet::from_iter(map_explode(online, &format!("gateway_{id}"))),
now,
)
}
});
}
state.relays.extend(online);
}
Transition::Idle => {
const IDLE_DURATION: Duration = Duration::from_secs(5 * 60);
let cut_off = state.flux_capacitor.now::<Instant>() + IDLE_DURATION;

View File

@@ -1,11 +1,17 @@
use super::sim_net::{any_ip_stack, any_port};
use super::{
sim_net::{any_ip_stack, any_port, Host},
strategies::relays,
};
use connlib_shared::{
messages::{client::ResourceDescription, DnsServer, ResourceId},
messages::{client::ResourceDescription, DnsServer, RelayId, ResourceId},
DomainName,
};
use hickory_proto::rr::RecordType;
use proptest::{prelude::*, sample};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::{
collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};
/// The possible transitions of the state machine.
#[derive(Clone, derivative::Derivative)]
@@ -66,6 +72,12 @@ pub(crate) enum Transition {
/// Reconnect to the portal.
ReconnectPortal,
/// Simulate deployment of new relays.
RelaysPresence {
disconnected: BTreeSet<RelayId>,
online: BTreeMap<RelayId, Host<u64>>,
},
/// Idle connlib for a while, forcing connection to auto-close.
Idle,
}
@@ -171,3 +183,12 @@ pub(crate) fn roam_client() -> impl Strategy<Value = Transition> {
port,
})
}
/// Strategy for a [`Transition::RelaysPresence`] transition.
///
/// The relays to take offline are drawn from the `disconnected` strategy supplied by the caller;
/// the relays that come online are generated by `relays()`.
pub(crate) fn migrate_relays(
    disconnected: impl Strategy<Value = BTreeSet<RelayId>>,
) -> impl Strategy<Value = Transition> {
    let online = relays();

    (disconnected, online).prop_map(|(old, new)| Transition::RelaysPresence {
        disconnected: old,
        online: new,
    })
}

View File

@@ -87,7 +87,7 @@ async fn main() -> Result<()> {
match role {
Role::Dialer => {
let mut node = ClientNode::<u64, u64>::new(private_key);
let mut node = ClientNode::<u64, u64>::new(private_key, rand::random());
node.update_relays(BTreeSet::new(), &relays, Instant::now());
let offer = node.new_connection(1, Instant::now(), Instant::now());
@@ -167,7 +167,7 @@ async fn main() -> Result<()> {
}
}
Role::Listener => {
let mut node = ServerNode::<u64, u64>::new(private_key);
let mut node = ServerNode::<u64, u64>::new(private_key, rand::random());
node.update_relays(BTreeSet::new(), &relays, Instant::now());
let offer = redis_connection