diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 38ca0e652..78ced0e58 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2005,7 +2005,6 @@ dependencies = [ "proptest", "proptest-state-machine", "rand 0.8.5", - "rand_core 0.6.4", "rangemap", "secrecy", "serde", @@ -5586,7 +5585,6 @@ dependencies = [ "boringtun", "bytecodec", "bytes", - "firezone-relay", "hex", "hex-display", "ip-packet", diff --git a/rust/connlib/snownet/Cargo.toml b/rust/connlib/snownet/Cargo.toml index 7ce8c024b..4bf229dcd 100644 --- a/rust/connlib/snownet/Cargo.toml +++ b/rust/connlib/snownet/Cargo.toml @@ -20,7 +20,6 @@ thiserror = "1" tracing = { workspace = true } [dev-dependencies] -firezone-relay = { workspace = true } tracing-subscriber = { version = "0.3", features = ["env-filter"] } [lints] diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 111c558ff..7e01c29f9 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -11,7 +11,9 @@ use core::fmt; use ip_packet::{ ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, MutableIpPacket, Packet as _, }; -use rand::random; +use rand::rngs::StdRng; +use rand::seq::IteratorRandom; +use rand::{random, SeedableRng}; use secrecy::{ExposeSecret, Secret}; use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet, HashSet}; @@ -96,6 +98,7 @@ pub struct Node { stats: NodeStats, marker: PhantomData, + rng: StdRng, } #[derive(thiserror::Error, Debug)] @@ -121,9 +124,10 @@ where TId: Eq + Hash + Copy + Ord + fmt::Display, RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, { - pub fn new(private_key: StaticSecret) -> Self { + pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self { let public_key = &(&private_key).into(); Self { + rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key. private_key, public_key: *public_key, marker: Default::default(), @@ -228,9 +232,12 @@ where } }; - if let Some(agent) = self.connections.agent_mut(cid) { - agent.add_remote_candidate(candidate.clone()); - } + let Some(agent) = self.connections.agent_mut(cid) else { + tracing::debug!("Unknown connection"); + return; + }; + + agent.add_remote_candidate(candidate.clone()); match candidate.kind() { CandidateKind::Host => { @@ -238,21 +245,22 @@ where // They are only useful to circumvent restrictive NATs in which case we are either talking to another relay candidate or a server-reflexive address. return; } - - CandidateKind::Relayed => { - // Optimisatically try to bind the channel only on the same relay as the remote peer. - if let Some(allocation) = self.same_relay_as_peer(&candidate) { - allocation.bind_channel(candidate.addr(), now); - return; - } - } - CandidateKind::ServerReflexive | CandidateKind::PeerReflexive => {} + CandidateKind::Relayed + | CandidateKind::ServerReflexive + | CandidateKind::PeerReflexive => {} } - // In other cases, bind on all relays. - for allocation in self.allocations.values_mut() { - allocation.bind_channel(candidate.addr(), now); - } + let Some(rid) = self.connections.relay(cid) else { + tracing::debug!("No relay selected for connection"); + return; + }; + + let Some(allocation) = self.allocations.get_mut(&rid) else { + tracing::debug!(%rid, "Unknown relay"); + return; + }; + + allocation.bind_channel(candidate.addr(), now); } #[tracing::instrument(level = "info", skip_all, fields(%cid))] @@ -270,20 +278,6 @@ where } } - /// Attempts to find the [`Allocation`] on the same relay as the remote's candidate. - /// - /// To do that, we need to check all candidates of each allocation and compare their IP. - /// The same relay might be reachable over IPv4 and IPv6. - #[must_use] - fn same_relay_as_peer(&mut self, candidate: &Candidate) -> Option<&mut Allocation> { - self.allocations.iter_mut().find_map(|(_, allocation)| { - allocation - .current_candidates() - .any(|c| c.addr().ip() == candidate.addr().ip()) - .then_some(allocation) - }) - } - /// Decapsulate an incoming packet. /// /// # Returns @@ -546,6 +540,7 @@ where mut agent: IceAgent, remote: PublicKey, key: [u8; 32], + relay: Option, intent_sent_at: Instant, now: Instant, ) -> Connection { @@ -576,6 +571,7 @@ where possible_sockets: BTreeSet::default(), buffered: RingBuffer::new(10), }, + relay, last_outgoing: now, last_incoming: now, } @@ -754,13 +750,14 @@ where fn bindings_and_allocations_drain_events(&mut self) { let allocation_events = self .allocations - .values_mut() - .flat_map(|allocation| allocation.poll_event()); + .iter_mut() + .flat_map(|(rid, allocation)| Some((*rid, allocation.poll_event()?))); - for event in allocation_events { + for (rid, event) in allocation_events { match event { CandidateEvent::New(candidate) => { add_local_candidate_to_all( + rid, candidate, &mut self.connections, &mut self.pending_events, @@ -776,6 +773,11 @@ where } } } + + /// Sample a relay to use for a new connection. + fn sample_relay(&mut self) -> Option { + self.allocations.keys().copied().choose(&mut self.rng) + } } impl Node @@ -819,6 +821,7 @@ where session_key, created_at: now, intent_sent_at, + relay: self.sample_relay(), is_failed: false, }; let duration_since_intent = initial_connection.duration_since_intent(now); @@ -850,12 +853,15 @@ where pass: answer.credentials.password, }); - self.seed_agent_with_local_candidates(cid, &mut agent); + let selected_relay = initial.relay; + + self.seed_agent_with_local_candidates(cid, selected_relay, &mut agent); let connection = self.init_connection( agent, remote, *initial.session_key.expose_secret(), + selected_relay, initial.intent_sent_at, now, ); @@ -911,12 +917,14 @@ where }, }; - self.seed_agent_with_local_candidates(cid, &mut agent); + let selected_relay = self.sample_relay(); + self.seed_agent_with_local_candidates(cid, selected_relay, &mut agent); let connection = self.init_connection( agent, remote, *offer.session_key.expose_secret(), + selected_relay, now, // Technically, this isn't fully correct because gateways don't send intents so we just use the current time. now, ); @@ -935,14 +943,25 @@ where TId: Eq + Hash + Copy + fmt::Display, RId: Copy + Eq + Hash + PartialEq + fmt::Debug + fmt::Display, { - fn seed_agent_with_local_candidates(&mut self, connection: TId, agent: &mut IceAgent) { + fn seed_agent_with_local_candidates( + &mut self, + connection: TId, + selected_relay: Option, + agent: &mut IceAgent, + ) { for candidate in self.host_candidates.iter().cloned() { add_local_candidate(connection, agent, candidate, &mut self.pending_events); } + let Some(selected_relay) = selected_relay else { + tracing::debug!("Skipping seeding of relay candidates: No relay selected"); + return; + }; + for candidate in self .allocations - .values() + .iter() + .filter_map(|(rid, allocation)| (*rid == selected_relay).then_some(allocation)) .flat_map(|allocation| allocation.current_candidates()) { add_local_candidate( @@ -956,7 +975,7 @@ where } struct Connections { - initial: BTreeMap, + initial: BTreeMap>, established: BTreeMap>, } @@ -1010,6 +1029,13 @@ where maybe_initial_connection.or(maybe_established_connection) } + fn relay(&mut self, id: TId) -> Option { + let maybe_initial_connection = self.initial.get_mut(&id).and_then(|i| i.relay); + let maybe_established_connection = self.established.get_mut(&id).and_then(|c| c.relay); + + maybe_initial_connection.or(maybe_established_connection) + } + fn agents_mut(&mut self) -> impl Iterator { let initial_agents = self.initial.iter_mut().map(|(id, c)| (*id, &mut c.agent)); let negotiated_agents = self @@ -1024,7 +1050,7 @@ where self.established.get_mut(id) } - fn iter_initial_mut(&mut self) -> impl Iterator { + fn iter_initial_mut(&mut self) -> impl Iterator)> { self.initial.iter_mut().map(|(id, conn)| (*id, conn)) } @@ -1082,22 +1108,27 @@ enum EncodeError { } fn add_local_candidate_to_all( + rid: RId, candidate: Candidate, connections: &mut Connections, pending_events: &mut VecDeque>, ) where TId: Copy + fmt::Display, + RId: Copy + PartialEq, { let initial_connections = connections .initial .iter_mut() - .map(|(id, c)| (*id, &mut c.agent)); + .flat_map(|(id, c)| Some((*id, &mut c.agent, c.relay?))); let established_connections = connections .established .iter_mut() - .map(|(id, c)| (*id, &mut c.agent)); + .flat_map(|(id, c)| Some((*id, &mut c.agent, c.relay?))); - for (cid, agent) in initial_connections.chain(established_connections) { + for (cid, agent, _) in initial_connections + .chain(established_connections) + .filter(|(_, _, selected_relay)| *selected_relay == rid) + { let _span = info_span!("connection", %cid).entered(); add_local_candidate(cid, agent, candidate.clone(), pending_events); @@ -1244,17 +1275,22 @@ pub(crate) enum CandidateEvent { Invalid(Candidate), } -struct InitialConnection { +struct InitialConnection { agent: IceAgent, session_key: Secret<[u8; 32]>, + /// The fallback relay we sampled for this potential connection. + /// + /// `None` if we don't have any relays available. + relay: Option, + created_at: Instant, intent_sent_at: Instant, is_failed: bool, } -impl InitialConnection { +impl InitialConnection { #[tracing::instrument(level = "debug", skip_all, fields(%cid))] fn handle_timeout(&mut self, cid: TId, now: Instant) where @@ -1293,6 +1329,11 @@ struct Connection { state: ConnectionState, + /// The relay we have selected for this connection. + /// + /// `None` if we didn't have any relays available. + relay: Option, + stats: ConnectionStats, intent_sent_at: Instant, signalling_completed_at: Instant, @@ -1308,6 +1349,7 @@ enum ConnectionState { Connecting { /// Socket addresses from which we might receive data (even before we are connected). possible_sockets: BTreeSet, + /// Packets emitted by wireguard whilst are still running ICE. /// /// This can happen if the remote's WG session initiation arrives at our socket before we nominate it. @@ -1514,6 +1556,7 @@ where ConnectionState::Connecting { possible_sockets, buffered, + .. } => { transmits.extend(buffered.into_iter().flat_map(|packet| { make_owned_transmit(remote_socket, &packet, allocations, now) diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 69b47b1b5..cbc878089 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -1,83 +1,11 @@ use boringtun::x25519::StaticSecret; -use firezone_relay::{AddressFamily, AllocationPort, ClientSocket, IpStack, PeerSocket}; -use ip_packet::*; -use rand::rngs::OsRng; -use snownet::{Answer, Client, ClientNode, Event, Node, RelaySocket, Server, ServerNode, Transmit}; +use snownet::{Answer, ClientNode, Event, ServerNode}; use std::{ - collections::{BTreeSet, HashSet, VecDeque}, iter, - net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, - time::{Duration, Instant, SystemTime}, - vec, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::{Duration, Instant}, }; use str0m::{net::Protocol, Candidate}; -use tracing::{debug_span, Span}; -use tracing_subscriber::util::SubscriberInitExt; - -#[test] -fn migrate_connection_to_new_relay() { - let _guard = setup_tracing(); - let mut clock = Clock::new(); - - let (alice, bob) = alice_and_bob(); - - let mut relays = [( - 1, - TestRelay::new( - SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3478), - debug_span!("Roger"), - ), - )]; - let mut alice = TestNode::new(debug_span!("Alice"), alice, "1.1.1.1:80").with_relays( - "alice", - BTreeSet::default(), - &mut relays, - clock.now, - ); - let mut bob = TestNode::new(debug_span!("Bob"), bob, "2.2.2.2:80").with_relays( - "bob", - BTreeSet::default(), - &mut relays, - clock.now, - ); - let firewall = Firewall::default() - .with_block_rule(&alice, &bob) - .with_block_rule(&bob, &alice); - - handshake(&mut alice, &mut bob, &clock); - - loop { - if alice.is_connected_to(&bob) && bob.is_connected_to(&alice) { - break; - } - - progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); - } - - // Swap out the relays. "Roger" is being removed (ID 1) and "Robert" is being added (ID 2). - let mut relays = [( - 2, - TestRelay::new( - SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 3478), - debug_span!("Robert"), - ), - )]; - alice = alice.with_relays("alice", BTreeSet::from([1]), &mut relays, clock.now); - bob = bob.with_relays("bob", BTreeSet::from([1]), &mut relays, clock.now); - - // Make some progress. (the fact that we only need 22 clock ticks means we are no relying on timeouts here (22 * 100ms = 2.2s)) - for _ in 0..22 { - progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); - } - - alice.ping(ip("9.9.9.9"), ip("8.8.8.8"), &bob, clock.now); - progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); - assert_eq!(bob.packets_from(ip("9.9.9.9")).count(), 1); - - bob.ping(ip("8.8.8.8"), ip("9.9.9.9"), &alice, clock.now); - progress(&mut alice, &mut bob, &mut relays, &firewall, &mut clock); - assert_eq!(alice.packets_from(ip("8.8.8.8")).count(), 1); -} #[test] fn connection_times_out_after_20_seconds() { @@ -147,10 +75,16 @@ fn answer_after_stale_connection_does_not_panic() { fn only_generate_candidate_event_after_answer() { let local_candidate = SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 10000); - let mut alice = ClientNode::::new(StaticSecret::random_from_rng(rand::thread_rng())); + let mut alice = ClientNode::::new( + StaticSecret::random_from_rng(rand::thread_rng()), + rand::random(), + ); alice.add_local_host_candidate(local_candidate).unwrap(); - let mut bob = ServerNode::::new(StaticSecret::random_from_rng(rand::thread_rng())); + let mut bob = ServerNode::::new( + StaticSecret::random_from_rng(rand::thread_rng()), + rand::random(), + ); let offer = alice.new_connection(1, Instant::now(), Instant::now()); @@ -173,17 +107,15 @@ fn only_generate_candidate_event_after_answer() { })); } -fn setup_tracing() -> tracing::subscriber::DefaultGuard { - tracing_subscriber::fmt() - .with_test_writer() - .with_env_filter("debug") - .finish() - .set_default() -} - fn alice_and_bob() -> (ClientNode, ServerNode) { - let alice = ClientNode::new(StaticSecret::random_from_rng(rand::thread_rng())); - let bob = ServerNode::new(StaticSecret::random_from_rng(rand::thread_rng())); + let alice = ClientNode::new( + StaticSecret::random_from_rng(rand::thread_rng()), + rand::random(), + ); + let bob = ServerNode::new( + StaticSecret::random_from_rng(rand::thread_rng()), + rand::random(), + ); (alice, bob) } @@ -207,513 +139,3 @@ fn host(socket: &str) -> String { fn s(socket: &str) -> SocketAddr { socket.parse().unwrap() } - -fn ip(ip: &str) -> IpAddr { - ip.parse().unwrap() -} - -// Heavily inspired by https://github.com/algesten/str0m/blob/7ed5143381cf095f7074689cc254b8c9e50d25c5/src/ice/mod.rs#L547-L647. -struct TestNode { - node: Node, - transmits: VecDeque>, - - span: Span, - received_packets: Vec>, - /// The primary interface we use to send packets (e.g. to relays). - primary: SocketAddr, - /// All local interfaces. - local: Vec, - events: Vec<(Event, Instant)>, - - buffer: Box<[u8; 10_000]>, -} - -struct TestRelay { - inner: firezone_relay::Server, - listen_addr: RelaySocket, - span: Span, - - allocations: HashSet<(AddressFamily, AllocationPort)>, - buffer: Vec, -} - -#[derive(Default)] -struct Firewall { - blocked: Vec<(SocketAddr, SocketAddr)>, -} - -struct Clock { - start: Instant, - now: Instant, - - tick_rate: Duration, - max_time: Instant, -} - -impl Clock { - fn new() -> Self { - let now = Instant::now(); - let tick_rate = Duration::from_millis(100); - let one_hour = Duration::from_secs(60) * 60; - - Self { - start: now, - now, - tick_rate, - max_time: now + one_hour, - } - } - - fn tick(&mut self) { - self.now += self.tick_rate; - - let elapsed = self.elapsed(self.start); - - if elapsed.as_millis() % 60_000 == 0 { - tracing::info!("Time since start: {elapsed:?}") - } - - if self.now >= self.max_time { - panic!("Time exceeded") - } - } - - fn elapsed(&self, start: Instant) -> Duration { - self.now.duration_since(start) - } -} - -impl Firewall { - fn with_block_rule(mut self, from: &TestNode, to: &TestNode) -> Self { - self.blocked.push((from.primary, to.primary)); - - self - } -} - -impl TestRelay { - fn new(local: impl Into, span: Span) -> Self { - let local = local.into(); - let inner = firezone_relay::Server::new(to_ip_stack(local), OsRng, 3478, 49152..=65535); - - Self { - inner, - listen_addr: local, - span, - allocations: HashSet::default(), - buffer: vec![0u8; (1 << 16) - 1], - } - } - - fn wants(&self, dst: SocketAddr) -> bool { - let is_v4_ctrl = self - .listen_addr - .as_v4() - .is_some_and(|v4| SocketAddr::V4(*v4) == dst); - let is_v6_ctrl = self - .listen_addr - .as_v6() - .is_some_and(|v6| SocketAddr::V6(*v6) == dst); - let is_allocation = self.allocations.contains(&match dst { - SocketAddr::V4(_) => (AddressFamily::V4, AllocationPort::new(dst.port())), - SocketAddr::V6(_) => (AddressFamily::V6, AllocationPort::new(dst.port())), - }); - - is_v4_ctrl || is_v6_ctrl || is_allocation - } - - fn matching_listen_socket(&self, other: SocketAddr) -> Option { - match other { - SocketAddr::V4(_) => Some(SocketAddr::V4(*self.listen_addr.as_v4()?)), - SocketAddr::V6(_) => Some(SocketAddr::V6(*self.listen_addr.as_v6()?)), - } - } - - fn ip4(&self) -> Option { - self.listen_addr.as_v4().map(|s| IpAddr::V4(*s.ip())) - } - - fn ip6(&self) -> Option { - self.listen_addr.as_v6().map(|s| IpAddr::V6(*s.ip())) - } - - fn handle_packet( - &mut self, - payload: &[u8], - sender: SocketAddr, - dst: SocketAddr, - other: &mut TestNode, - now: Instant, - ) { - if self.listen_addr.matches(dst) { - self.handle_client_input(payload, ClientSocket::new(sender), other, now); - return; - } - - self.handle_peer_traffic( - payload, - PeerSocket::new(sender), - AllocationPort::new(dst.port()), - other, - now, - ) - } - - fn handle_client_input( - &mut self, - payload: &[u8], - client: ClientSocket, - receiver: &mut TestNode, - now: Instant, - ) { - if let Some((port, peer)) = self - .span - .in_scope(|| self.inner.handle_client_input(payload, client, now)) - { - let payload = &payload[4..]; - - // The `dst` of the relayed packet is what TURN calls a "peer". - let dst = peer.into_socket(); - - // The `src_ip` is the relay's IP - let src_ip = match dst { - SocketAddr::V4(_) => { - assert!( - self.allocations.contains(&(AddressFamily::V4, port)), - "IPv4 allocation to be present if we want to send to an IPv4 socket" - ); - - self.ip4().expect("listen on IPv4 if we have an allocation") - } - SocketAddr::V6(_) => { - assert!( - self.allocations.contains(&(AddressFamily::V6, port)), - "IPv6 allocation to be present if we want to send to an IPv6 socket" - ); - - self.ip6().expect("listen on IPv6 if we have an allocation") - } - }; - - // The `src` of the relayed packet is the relay itself _from_ the allocated port. - let src = SocketAddr::new(src_ip, port.value()); - - // Check if we need to relay to ourselves (from one allocation to another) - if self.wants(dst) { - // When relaying to ourselves, we become our own peer. - let peer_socket = PeerSocket::new(src); - // The allocation that the data is arriving on is the `dst`'s port. - let allocation_port = AllocationPort::new(dst.port()); - - self.handle_peer_traffic(payload, peer_socket, allocation_port, receiver, now); - - return; - } - - receiver.receive(dst, src, payload, now); - } - } - - fn handle_peer_traffic( - &mut self, - payload: &[u8], - peer: PeerSocket, - port: AllocationPort, - receiver: &mut TestNode, - now: Instant, - ) { - if let Some((client, channel)) = self - .span - .in_scope(|| self.inner.handle_peer_traffic(payload, peer, port)) - { - let full_length = firezone_relay::ChannelData::encode_header_to_slice( - channel, - payload.len() as u16, - &mut self.buffer[..4], - ); - self.buffer[4..full_length].copy_from_slice(payload); - - let receiving_socket = client.into_socket(); - let sending_socket = self.matching_listen_socket(receiving_socket).unwrap(); - receiver.receive( - receiving_socket, - sending_socket, - &self.buffer[..full_length], - now, - ); - } - } - - fn drain_messages( - &mut self, - a1: &mut TestNode, - a2: &mut TestNode, - now: Instant, - ) { - while let Some(command) = self.inner.next_command() { - match command { - firezone_relay::Command::SendMessage { payload, recipient } => { - let recipient = recipient.into_socket(); - let sending_socket = self.matching_listen_socket(recipient).unwrap(); - - if a1.local.contains(&recipient) { - a1.receive(recipient, sending_socket, &payload, now); - continue; - } - - if a2.local.contains(&recipient) { - a2.receive(recipient, sending_socket, &payload, now); - continue; - } - - panic!("Relay generated traffic for unknown client") - } - firezone_relay::Command::CreateAllocation { port, family } => { - self.allocations.insert((family, port)); - } - firezone_relay::Command::FreeAllocation { port, family } => { - self.allocations.remove(&(family, port)); - } - } - } - } - - fn make_credentials(&self, username: &str) -> (String, String) { - let expiry = SystemTime::now() + Duration::from_secs(60); - - let secs = expiry - .duration_since(SystemTime::UNIX_EPOCH) - .expect("expiry must be later than UNIX_EPOCH") - .as_secs(); - - let password = - firezone_relay::auth::generate_password(self.inner.auth_secret(), expiry, username); - - (format!("{secs}:{username}"), password) - } -} - -fn to_ip_stack(socket: RelaySocket) -> IpStack { - match socket { - RelaySocket::V4(v4) => IpStack::Ip4(*v4.ip()), - RelaySocket::V6(v6) => IpStack::Ip6(*v6.ip()), - RelaySocket::Dual { v4, v6 } => IpStack::Dual { - ip4: *v4.ip(), - ip6: *v6.ip(), - }, - } -} - -impl TestNode { - pub fn new(span: Span, node: Node, primary: &str) -> Self { - let primary = primary.parse().unwrap(); - - TestNode { - node, - span, - received_packets: vec![], - buffer: Box::new([0u8; 10_000]), - primary, - local: vec![primary], - events: Default::default(), - transmits: Default::default(), - } - } - - fn with_relays( - mut self, - username: &str, - to_remove: BTreeSet, - relays: &mut [(u64, TestRelay)], - now: Instant, - ) -> Self { - let turn_servers = relays - .iter() - .map(|(idx, relay)| { - let (username, password) = relay.make_credentials(username); - - ( - *idx, - relay.listen_addr, - username, - password, - "firezone".to_owned(), - ) - }) - .collect::>(); - - self.span - .in_scope(|| self.node.update_relays(to_remove, &turn_servers, now)); - - self - } - - fn is_connected_to(&self, other: &TestNode) -> bool { - self.node.connection_id(other.node.public_key()).is_some() - } - - fn ping(&mut self, src: IpAddr, dst: IpAddr, other: &TestNode, now: Instant) { - let id = self - .node - .connection_id(other.node.public_key()) - .expect("cannot ping not-connected node"); - - let transmit = self - .span - .in_scope(|| { - self.node.encapsulate( - id, - ip_packet::make::icmp_request_packet(src, dst, 1, 0).to_immutable(), - now, - ) - }) - .unwrap() - .unwrap() - .into_owned(); - - self.transmits.push_back(transmit); - } - - fn packets_from(&self, src: IpAddr) -> impl Iterator> { - self.received_packets - .iter() - .filter(move |p| p.source() == src) - } - - fn receive(&mut self, local: SocketAddr, from: SocketAddr, packet: &[u8], now: Instant) { - debug_assert!(self.local.contains(&local)); - - if let Some((_, packet)) = self - .span - .in_scope(|| { - self.node - .decapsulate(local, from, packet, now, self.buffer.as_mut()) - }) - .unwrap() - { - self.received_packets.push(packet.to_immutable().to_owned()) - } - } - - fn drain_events(&mut self, other: &mut TestNode, now: Instant) { - while let Some(v) = self.span.in_scope(|| self.node.poll_event()) { - self.events.push((v.clone(), now)); - - match v { - Event::NewIceCandidate { - connection, - candidate, - } => other - .span - .in_scope(|| other.node.add_remote_candidate(connection, candidate, now)), - Event::InvalidateIceCandidate { - connection, - candidate, - } => other - .span - .in_scope(|| other.node.remove_remote_candidate(connection, candidate)), - Event::ConnectionEstablished(_) - | Event::ConnectionFailed(_) - | Event::ConnectionClosed(_) => {} - }; - } - } - - fn drain_transmits( - &mut self, - other: &mut TestNode, - relays: &mut [(u64, TestRelay)], - firewall: &Firewall, - now: Instant, - ) { - for trans in iter::from_fn(|| self.node.poll_transmit()).chain(self.transmits.drain(..)) { - let payload = &trans.payload; - let dst = trans.dst; - - if let Some((_, relay)) = relays.iter_mut().find(|(_, r)| r.wants(trans.dst)) { - relay.handle_packet(payload, self.primary, dst, other, now); - continue; - } - - let Some(src) = trans.src else { - tracing::debug!(target: "router", %dst, "Unknown relay"); - continue; - }; - - // Wasn't traffic for the relay, let's check our firewall. - if firewall.blocked.contains(&(src, dst)) { - tracing::debug!(target: "firewall", %src, %dst, "Dropping packet"); - continue; - } - - if !other.local.contains(&dst) { - tracing::debug!(target: "router", %src, %dst, "Unknown destination"); - continue; - } - - // Firewall allowed traffic, let's dispatch it. - other.receive(dst, src, payload, now); - } - } -} - -fn handshake(client: &mut TestNode, server: &mut TestNode, clock: &Clock) { - let offer = client - .span - .in_scope(|| client.node.new_connection(1, clock.now, clock.now)); - let answer = server.span.in_scope(|| { - server - .node - .accept_connection(1, offer, client.node.public_key(), clock.now) - }); - client.span.in_scope(|| { - client - .node - .accept_answer(1, server.node.public_key(), answer, clock.now) - }); -} - -fn progress( - a1: &mut TestNode, - a2: &mut TestNode, - relays: &mut [(u64, TestRelay)], - firewall: &Firewall, - clock: &mut Clock, -) { - clock.tick(); - - a1.drain_events(a2, clock.now); - a2.drain_events(a1, clock.now); - - a1.drain_transmits(a2, relays, firewall, clock.now); - a2.drain_transmits(a1, relays, firewall, clock.now); - - for (_, relay) in relays.iter_mut() { - relay.drain_messages(a1, a2, clock.now); - } - - if let Some(timeout) = a1.node.poll_timeout() { - if clock.now >= timeout { - a1.span.in_scope(|| a1.node.handle_timeout(clock.now)); - } - } - - if let Some(timeout) = a2.node.poll_timeout() { - if clock.now >= timeout { - a2.span.in_scope(|| a2.node.handle_timeout(clock.now)); - } - } - - for (_, relay) in relays { - if let Some(timeout) = relay.inner.poll_timeout() { - if clock.now >= timeout { - relay - .span - .in_scope(|| relay.inner.handle_timeout(clock.now)) - } - } - } - - a1.drain_events(a2, clock.now); - a2.drain_events(a1, clock.now); -} diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index c62201166..aa621e0ea 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -22,7 +22,7 @@ ip_network = { version = "0.4", default-features = false } ip_network_table = { version = "0.2", default-features = false } itertools = { version = "0.13", default-features = false, features = ["use_std"] } proptest = { version = "1", optional = true } -rand_core = { version = "0.6", default-features = false, features = ["getrandom"] } +rand = "0.8.5" rangemap = "1.5.1" secrecy = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive", "std"] } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 6feadd222..59a06cbc4 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -291,6 +291,7 @@ impl ClientState { pub(crate) fn new( private_key: impl Into, known_hosts: HashMap>, + seed: [u8; 32], ) -> Self { Self { awaiting_connection_details: Default::default(), @@ -303,7 +304,7 @@ impl ClientState { interface_config: Default::default(), buffered_packets: Default::default(), buffered_dns_queries: Default::default(), - node: ClientNode::new(private_key.into()), + node: ClientNode::new(private_key.into(), seed), system_resolvers: Default::default(), sites_status: Default::default(), gateways_site: Default::default(), @@ -1323,7 +1324,7 @@ impl IpProvider { #[cfg(test)] mod tests { use super::*; - use rand_core::OsRng; + use rand::rngs::OsRng; #[test] fn ignores_ip4_igmp_multicast() { @@ -1496,7 +1497,11 @@ mod tests { impl ClientState { pub fn for_test() -> ClientState { - ClientState::new(StaticSecret::random_from_rng(OsRng), HashMap::new()) + ClientState::new( + StaticSecret::random_from_rng(OsRng), + HashMap::new(), + rand::random(), + ) } } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 27af89cd0..01ecc8859 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -131,10 +131,10 @@ pub struct GatewayState { } impl GatewayState { - pub(crate) fn new(private_key: impl Into) -> Self { + pub(crate) fn new(private_key: impl Into, seed: [u8; 32]) -> Self { Self { peers: Default::default(), - node: ServerNode::new(private_key.into()), + node: ServerNode::new(private_key.into(), seed), next_expiry_resources_check: Default::default(), buffered_events: VecDeque::default(), } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index dce7b1eb9..89db6157e 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -85,7 +85,7 @@ impl ClientTunnel { ) -> std::io::Result { Ok(Self { io: Io::new(tcp_socket_factory, udp_socket_factory)?, - role_state: ClientState::new(private_key, known_hosts), + role_state: ClientState::new(private_key, known_hosts, rand::random()), write_buf: Box::new([0u8; DEFAULT_MTU + 16 + 20]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), @@ -183,7 +183,7 @@ impl GatewayTunnel { pub fn new(private_key: StaticSecret) -> std::io::Result { Ok(Self { io: Io::new(Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp))?, - role_state: GatewayState::new(private_key), + role_state: GatewayState::new(private_key, rand::random()), write_buf: Box::new([0u8; DEFAULT_MTU + 20 + 16]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 2b1c649ff..166548e95 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -1,5 +1,5 @@ use super::{ - composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*, sim_relay::*, + composite_strategy::CompositeStrategy, sim_client::*, sim_gateway::*, sim_net::*, strategies::*, stub_portal::StubPortal, transition::*, }; use crate::dns::is_subdomain; @@ -8,11 +8,9 @@ use connlib_shared::{ client::{self, ResourceDescription}, GatewayId, RelayId, }, - proptest::*, DomainName, StaticSecret, }; use hickory_proto::rr::RecordType; -use prop::collection; use proptest::{prelude::*, sample}; use proptest_state_machine::ReferenceStateMachine; use std::{ @@ -62,7 +60,7 @@ impl ReferenceStateMachine for ReferenceState { ( ref_client_host(Just(client_tunnel_ip4), Just(client_tunnel_ip6)), gateways_and_portal(), - collection::btree_map(relay_id(), relay_prototype(), 1..=2), + relays(), global_dns_records(), // Start out with a set of global DNS records so we have something to resolve outside of DNS resources. any::(), ) @@ -167,6 +165,10 @@ impl ReferenceStateMachine for ReferenceState { sample::select(resource_ids).prop_map(Transition::DeactivateResource) }) .with(1, roam_client()) + .with( + 1, + migrate_relays(Just(state.relays.keys().copied().collect())), // Always take down all relays because we can't know which one was sampled for the connection. + ) .with(1, Just(Transition::ReconnectPortal)) .with(1, Just(Transition::Idle)) .with_if_not_empty( @@ -376,19 +378,33 @@ impl ReferenceStateMachine for ReferenceState { .add_host(state.client.inner().id, &state.client)); // When roaming, we are not connected to any resource and wait for the next packet to re-establish a connection. - state.client.exec_mut(|client| { - client.connected_cidr_resources.clear(); - client.connected_dns_resources.clear(); - }); + state.client.exec_mut(|client| client.reset_connections()); } Transition::ReconnectPortal => { // Reconnecting to the portal should have no noticeable impact on the data plane. } + Transition::RelaysPresence { + disconnected, + online, + } => { + for rid in disconnected { + let disconnected_relay = + state.relays.remove(rid).expect("old host to be present"); + state.network.remove_host(&disconnected_relay); + } + + for (rid, online_relay) in online { + state.relays.insert(*rid, online_relay.clone()); + debug_assert!(state.network.add_host(*rid, online_relay)); + } + + // In case we were using the relays, all connections will be cut and require us to make a new one. + if state.drop_direct_client_traffic { + state.client.exec_mut(|client| client.reset_connections()); + } + } Transition::Idle => { - state.client.exec_mut(|client| { - client.connected_cidr_resources.clear(); - client.connected_dns_resources.clear(); - }); + state.client.exec_mut(|client| client.reset_connections()); } }; @@ -527,6 +543,26 @@ impl ReferenceStateMachine for ReferenceState { Transition::DeactivateResource(r) => { state.client.inner().all_resource_ids().contains(r) } + Transition::RelaysPresence { + disconnected, + online, + } => { + let all_old_are_present = disconnected + .iter() + .all(|rid| state.relays.contains_key(rid)); + let no_new_are_present = online.keys().all(|rid| !state.relays.contains_key(rid)); + + let mut additional_routes = RoutingTable::default(); + for (rid, relay) in online { + if !additional_routes.add_host(*rid, relay) { + return false; + } + } + + let route_overlap = state.network.overlaps_with(&additional_routes); + + all_old_are_present && no_new_are_present && !route_overlap + } Transition::Idle => true, } } @@ -560,7 +596,7 @@ pub(crate) fn private_key() -> impl Strategy { } #[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub(crate) struct PrivateKey([u8; 32]); +pub(crate) struct PrivateKey(pub [u8; 32]); impl From for StaticSecret { fn from(key: PrivateKey) -> Self { diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 5e6736d18..2da65b097 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -255,7 +255,7 @@ impl RefClient { /// /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self) -> SimClient { - let mut client_state = ClientState::new(self.key, self.known_hosts); + let mut client_state = ClientState::new(self.key, self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed. let _ = client_state.update_interface_config(Interface { ipv4: self.tunnel_ip4, ipv6: self.tunnel_ip6, @@ -266,6 +266,11 @@ impl RefClient { SimClient::new(self.id, client_state) } + pub(crate) fn reset_connections(&mut self) { + self.connected_cidr_resources.clear(); + self.connected_dns_resources.clear(); + } + pub(crate) fn is_tunnel_ip(&self, ip: IpAddr) -> bool { match ip { IpAddr::V4(ip4) => self.tunnel_ip4 == ip4, diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 06656bffd..94614a021 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -99,7 +99,7 @@ impl RefGateway { /// /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self) -> SimGateway { - SimGateway::new(GatewayState::new(self.key)) + SimGateway::new(GatewayState::new(self.key, self.key.0)) // Cheating a bit here by reusing the key as seed. } } diff --git a/rust/connlib/tunnel/src/tests/sim_net.rs b/rust/connlib/tunnel/src/tests/sim_net.rs index 94d0d3228..eb241973b 100644 --- a/rust/connlib/tunnel/src/tests/sim_net.rs +++ b/rust/connlib/tunnel/src/tests/sim_net.rs @@ -236,6 +236,13 @@ impl RoutingTable { pub(crate) fn host_by_ip(&self, ip: IpAddr) -> Option { self.routes.exact_match(ip).copied() } + + pub(crate) fn overlaps_with(&self, other: &Self) -> bool { + other + .routes + .iter() + .any(|(route, _)| self.routes.exact_match(route).is_some()) + } } #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Hash)] diff --git a/rust/connlib/tunnel/src/tests/sim_relay.rs b/rust/connlib/tunnel/src/tests/sim_relay.rs index 2c1a997c6..9cdeec379 100644 --- a/rust/connlib/tunnel/src/tests/sim_relay.rs +++ b/rust/connlib/tunnel/src/tests/sim_relay.rs @@ -196,7 +196,7 @@ impl SimRelay { } } -pub(crate) fn relay_prototype() -> impl Strategy> { +pub(crate) fn ref_relay_host() -> impl Strategy> { host( dual_ip_stack(), // For this test, our relays always run in dual-stack mode to ensure connectivity! Just(3478), diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index 97cd88df9..ecb859c36 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -1,16 +1,18 @@ use super::{ sim_gateway::{ref_gateway_host, RefGateway}, sim_net::Host, + sim_relay::ref_relay_host, stub_portal::StubPortal, }; use crate::client::{IPV4_RESOURCES, IPV6_RESOURCES}; use connlib_shared::{ messages::{ client::{ResourceDescriptionCidr, ResourceDescriptionDns, Site, SiteId}, - DnsServer, GatewayId, + DnsServer, GatewayId, RelayId, }, proptest::{ - any_ip_network, cidr_resource, dns_resource, domain_label, domain_name, gateway_id, site, + any_ip_network, cidr_resource, dns_resource, domain_label, domain_name, gateway_id, + relay_id, site, }, DomainName, }; @@ -195,6 +197,10 @@ pub(crate) fn gateways_and_portal() -> impl Strategy< ) } +pub(crate) fn relays() -> impl Strategy>> { + collection::btree_map(relay_id(), ref_relay_host(), 1..=2) +} + fn any_site(sites: HashSet) -> impl Strategy { sample::select(Vec::from_iter(sites)) } diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index ea40c2308..8484b4698 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -281,6 +281,53 @@ impl StateMachineTest for TunnelTest { c.sut.set_resources(all_resources); }); } + Transition::RelaysPresence { + disconnected, + online, + } => { + for rid in &disconnected { + let disconnected_relay = + state.relays.remove(rid).expect("old relay to be present"); + state.network.remove_host(&disconnected_relay); + } + + let online = online + .into_iter() + .map(|(rid, relay)| (rid, relay.map(SimRelay::new, debug_span!("relay", %rid)))) + .collect::>(); + + for (rid, relay) in &online { + debug_assert!(state.network.add_host(*rid, relay)); + } + + state.client.exec_mut({ + let disconnected = disconnected.clone(); + let online = online.iter(); + + move |c| { + c.sut.update_relays( + disconnected, + BTreeSet::from_iter(map_explode(online, "client")), + now, + ); + } + }); + for (id, gateway) in &mut state.gateways { + gateway.exec_mut({ + let disconnected = disconnected.clone(); + let online = online.iter(); + + move |g| { + g.sut.update_relays( + disconnected, + BTreeSet::from_iter(map_explode(online, &format!("gateway_{id}"))), + now, + ) + } + }); + } + state.relays.extend(online); + } Transition::Idle => { const IDLE_DURATION: Duration = Duration::from_secs(5 * 60); let cut_off = state.flux_capacitor.now::() + IDLE_DURATION; diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 735d39cac..80e8bde08 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -1,11 +1,17 @@ -use super::sim_net::{any_ip_stack, any_port}; +use super::{ + sim_net::{any_ip_stack, any_port, Host}, + strategies::relays, +}; use connlib_shared::{ - messages::{client::ResourceDescription, DnsServer, ResourceId}, + messages::{client::ResourceDescription, DnsServer, RelayId, ResourceId}, DomainName, }; use hickory_proto::rr::RecordType; use proptest::{prelude::*, sample}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::{ + collections::{BTreeMap, BTreeSet}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, +}; /// The possible transitions of the state machine. #[derive(Clone, derivative::Derivative)] @@ -66,6 +72,12 @@ pub(crate) enum Transition { /// Reconnect to the portal. ReconnectPortal, + /// Simulate deployment of new relays. + RelaysPresence { + disconnected: BTreeSet, + online: BTreeMap>, + }, + /// Idle connlib for a while, forcing connection to auto-close. Idle, } @@ -171,3 +183,12 @@ pub(crate) fn roam_client() -> impl Strategy { port, }) } + +pub(crate) fn migrate_relays( + disconnected: impl Strategy>, +) -> impl Strategy { + (disconnected, relays()).prop_map(|(disconnected, online)| Transition::RelaysPresence { + disconnected, + online, + }) +} diff --git a/rust/snownet-tests/src/main.rs b/rust/snownet-tests/src/main.rs index 2e24bf0d6..4487c660d 100644 --- a/rust/snownet-tests/src/main.rs +++ b/rust/snownet-tests/src/main.rs @@ -87,7 +87,7 @@ async fn main() -> Result<()> { match role { Role::Dialer => { - let mut node = ClientNode::::new(private_key); + let mut node = ClientNode::::new(private_key, rand::random()); node.update_relays(BTreeSet::new(), &relays, Instant::now()); let offer = node.new_connection(1, Instant::now(), Instant::now()); @@ -167,7 +167,7 @@ async fn main() -> Result<()> { } } Role::Listener => { - let mut node = ServerNode::::new(private_key); + let mut node = ServerNode::::new(private_key, rand::random()); node.update_relays(BTreeSet::new(), &relays, Instant::now()); let offer = redis_connection