From b7795dfa03267fde7779fcae942eb427e30c8bd8 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 8 Oct 2024 12:37:42 +1100 Subject: [PATCH] feat(snownet): introduce `upsert_connection` (#6937) One of the key differences of the new control protocol designed in #6461 is that creating new connections is idempotent. We achieve this by having the portal generate the ICE credentials and the preshared-key for the WireGuard tunnel. As long as the ICE credentials don't change, we don't need to make a new connection. For `snownet`, this means we are deprecating the previous APIs for making connections. The client-side APIs will have to stay around until we merge the client-part of the new control protocol. The server-side APIs will have to stay around until we remove backwards-compatibility from the gateway. --- rust/connlib/snownet/src/lib.rs | 6 +- rust/connlib/snownet/src/node.rs | 116 +++++++++++++++++++++++++-- rust/connlib/snownet/tests/lib.rs | 55 ++++++++++--- rust/connlib/tunnel/src/client.rs | 6 ++ rust/connlib/tunnel/src/gateway.rs | 2 + rust/connlib/tunnel/src/tests/sut.rs | 1 + 6 files changed, 167 insertions(+), 19 deletions(-) diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index 454b9d715..a0aa1dc12 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -10,8 +10,10 @@ mod stats; mod utils; pub use allocation::RelaySocket; +#[allow(deprecated)] // Rust bug: `expect` doesn't seem to work on imports? +pub use node::{Answer, Offer}; pub use node::{ - Answer, Client, ClientNode, Credentials, EncryptBuffer, EncryptedPacket, Error, Event, Node, - Offer, Server, ServerNode, Transmit, HANDSHAKE_TIMEOUT, + Client, ClientNode, Credentials, EncryptBuffer, EncryptedPacket, Error, Event, Node, Server, + ServerNode, Transmit, HANDSHAKE_TIMEOUT, }; pub use stats::{ConnectionStats, NodeStats}; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 054815e82..794dd383a 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -20,7 +20,6 @@ use std::borrow::Cow; use std::collections::btree_map::Entry; use std::collections::{BTreeMap, BTreeSet}; use std::hash::Hash; -use std::marker::PhantomData; use std::mem; use std::ops::ControlFlow; use std::time::{Duration, Instant}; @@ -45,8 +44,36 @@ pub type ServerNode = Node; /// Manages a set of wireguard connections for a client. pub type ClientNode = Node; -pub enum Server {} -pub enum Client {} +#[non_exhaustive] +pub struct Server {} + +#[non_exhaustive] +pub struct Client {} + +trait Mode { + fn new() -> Self; + fn is_client(&self) -> bool; +} + +impl Mode for Server { + fn is_client(&self) -> bool { + false + } + + fn new() -> Self { + Self {} + } +} + +impl Mode for Client { + fn is_client(&self) -> bool { + true + } + + fn new() -> Self { + Self {} + } +} /// A node within a `snownet` network maintains connections to several other nodes. /// @@ -96,7 +123,7 @@ pub struct Node { stats: NodeStats, - marker: PhantomData, + mode: T, rng: StdRng, } @@ -118,10 +145,12 @@ pub enum Error { BadLocalAddress(#[from] str0m::error::IceError), } +#[expect(private_bounds, reason = "We don't want `Mode` to be public API")] impl Node where TId: Eq + Hash + Copy + Ord + fmt::Display, RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, + T: Mode, { pub fn new(seed: [u8; 32]) -> Self { let mut rng = StdRng::from_seed(seed); @@ -133,7 +162,7 @@ where session_id: SessionId::new(*public_key), private_key, public_key: *public_key, - marker: Default::default(), + mode: T::new(), index: IndexLfsr::default(), rate_limiter: Arc::new(RateLimiter::new(public_key, HANDSHAKE_RATE_LIMIT)), host_candidates: Default::default(), @@ -189,6 +218,65 @@ where self.connections.len() } + /// Upserts a connection to the given remote. + /// + /// If we already have a connection with the same ICE credentials, this does nothing. + /// Otherwise, the existing connection is discarded and a new one will be created. + #[tracing::instrument(level = "info", skip_all, fields(%cid))] + pub fn upsert_connection( + &mut self, + cid: TId, + remote: PublicKey, + session_key: Secret<[u8; 32]>, + local_creds: Credentials, + remote_creds: Credentials, + now: Instant, + ) { + let local_creds = local_creds.into(); + let remote_creds = remote_creds.into(); + + if self.connections.initial.contains_key(&cid) { + debug_assert!(false, "The new `upsert_connection` API is incompatible with the previous `new_connection` API"); + return; + } + + if self + .connections + .get_established_mut(&cid) + .is_some_and(|c| c.agent.local_credentials() == &local_creds) + { + tracing::debug!("Already got a connection"); + return; + } + + let selected_relay = self.sample_relay(); + + let mut agent = new_agent(); + agent.set_controlling(self.mode.is_client()); + agent.set_local_credentials(local_creds); + agent.set_remote_credentials(remote_creds); + + self.seed_agent_with_local_candidates(cid, selected_relay, &mut agent); + + let connection = self.init_connection( + cid, + agent, + remote, + *session_key.expose_secret(), + selected_relay, + now, + now, + ); + + let existing = self.connections.established.insert(cid, connection); + + if existing.is_some() { + tracing::info!("Replaced existing connection"); + } else { + tracing::info!("Created new connection"); + } + } + pub fn public_key(&self) -> PublicKey { self.public_key } @@ -853,6 +941,8 @@ where /// The returned [`Offer`] must be passed to the remote via a signalling channel. #[tracing::instrument(level = "info", skip_all, fields(%cid))] #[must_use] + #[deprecated] + #[expect(deprecated)] pub fn new_connection(&mut self, cid: TId, intent_sent_at: Instant, now: Instant) -> Offer { if self.connections.initial.remove(&cid).is_some() { tracing::info!("Replacing existing initial connection"); @@ -902,6 +992,8 @@ where /// Accept an [`Answer`] from the remote for a connection previously created via [`Node::new_connection`]. #[tracing::instrument(level = "info", skip_all, fields(%cid))] + #[deprecated] + #[expect(deprecated)] pub fn accept_answer(&mut self, cid: TId, remote: PublicKey, answer: Answer, now: Instant) { let Some(initial) = self.connections.initial.remove(&cid) else { tracing::debug!("No initial connection state, ignoring answer"); // This can happen if the connection setup timed out. @@ -948,6 +1040,8 @@ where /// The returned [`Answer`] must be passed to the remote via a signalling channel. #[tracing::instrument(level = "info", skip_all, fields(%cid))] #[must_use] + #[deprecated] + #[expect(deprecated)] pub fn accept_connection( &mut self, cid: TId, @@ -1334,12 +1428,14 @@ fn remove_local_candidate( } } +#[deprecated] pub struct Offer { /// The Wireguard session key for a connection. pub session_key: Secret<[u8; 32]>, pub credentials: Credentials, } +#[deprecated] pub struct Answer { pub credentials: Credentials, } @@ -1351,6 +1447,16 @@ pub struct Credentials { pub password: String, } +#[doc(hidden)] // Not public API. +impl From for str0m::IceCreds { + fn from(value: Credentials) -> Self { + str0m::IceCreds { + ufrag: value.username, + pass: value.password, + } + } +} + #[derive(Debug, PartialEq, Clone)] pub enum Event { /// We created a new candidate for this connection and ask to signal it to the remote party. diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index d3e2dfdbc..2e0b0f53e 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -1,4 +1,5 @@ -use snownet::{Answer, ClientNode, Event, ServerNode}; +use secrecy::Secret; +use snownet::{ClientNode, Credentials, Event, ServerNode}; use std::{ iter, net::{IpAddr, Ipv4Addr, SocketAddr}, @@ -7,6 +8,7 @@ use std::{ use str0m::{net::Protocol, Candidate}; #[test] +#[expect(deprecated, reason = "Will be deleted together with deprecated API")] fn connection_times_out_after_20_seconds() { let (mut alice, _) = alice_and_bob(); @@ -24,12 +26,9 @@ fn connection_without_candidates_times_out_after_10_seconds() { let start = Instant::now(); let (mut alice, mut bob) = alice_and_bob(); - let answer = send_offer(&mut alice, &mut bob, start); + handshake(&mut alice, &mut bob, start); - let accepted_at = start + Duration::from_secs(1); - alice.accept_answer(1, bob.public_key(), answer, accepted_at); - - alice.handle_timeout(accepted_at + Duration::from_secs(10)); + alice.handle_timeout(start + Duration::from_secs(10)); assert_eq!(alice.poll_event().unwrap(), Event::ConnectionFailed(1)); } @@ -40,14 +39,12 @@ fn connection_with_candidates_does_not_time_out_after_10_seconds() { let start = Instant::now(); let (mut alice, mut bob) = alice_and_bob(); - let answer = send_offer(&mut alice, &mut bob, start); + handshake(&mut alice, &mut bob, start); - let accepted_at = start + Duration::from_secs(1); - alice.accept_answer(1, bob.public_key(), answer, accepted_at); alice.add_local_host_candidate(s("10.0.0.2:4444")).unwrap(); - alice.add_remote_candidate(1, host("10.0.0.1:4444"), accepted_at); + alice.add_remote_candidate(1, host("10.0.0.1:4444"), start); - alice.handle_timeout(accepted_at + Duration::from_secs(10)); + alice.handle_timeout(start + Duration::from_secs(10)); let any_failed = iter::from_fn(|| alice.poll_event()).any(|e| matches!(e, Event::ConnectionFailed(_))); @@ -56,6 +53,7 @@ fn connection_with_candidates_does_not_time_out_after_10_seconds() { } #[test] +#[expect(deprecated, reason = "Will be deleted together with deprecated API")] fn answer_after_stale_connection_does_not_panic() { let start = Instant::now(); @@ -69,6 +67,7 @@ fn answer_after_stale_connection_does_not_panic() { } #[test] +#[expect(deprecated, reason = "Will be deleted together with deprecated API")] fn only_generate_candidate_event_after_answer() { let local_candidate = SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 10000); @@ -105,16 +104,48 @@ fn alice_and_bob() -> (ClientNode, ServerNode) { (alice, bob) } +#[expect(deprecated, reason = "Will be deleted together with deprecated API")] fn send_offer( alice: &mut ClientNode, bob: &mut ServerNode, now: Instant, -) -> Answer { +) -> snownet::Answer { let offer = alice.new_connection(1, Instant::now(), now); bob.accept_connection(1, offer, alice.public_key(), now) } +fn handshake(alice: &mut ClientNode, bob: &mut ServerNode, now: Instant) { + alice.upsert_connection( + 1, + bob.public_key(), + Secret::new([0u8; 32]), + Credentials { + username: "foo".to_owned(), + password: "foo".to_owned(), + }, + Credentials { + username: "bar".to_owned(), + password: "bar".to_owned(), + }, + now, + ); + bob.upsert_connection( + 1, + alice.public_key(), + Secret::new([0u8; 32]), + Credentials { + username: "bar".to_owned(), + password: "bar".to_owned(), + }, + Credentials { + username: "foo".to_owned(), + password: "foo".to_owned(), + }, + now, + ); +} + fn host(socket: &str) -> String { Candidate::host(s(socket), Protocol::Udp) .unwrap() diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 87d3a3165..1d41194c7 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -179,6 +179,7 @@ impl ClientTunnel { .on_routing_details(resource_id, gateway_id, site_id, Instant::now()) } + #[expect(deprecated, reason = "Will be deleted together with deprecated API")] pub fn received_offer_response( &mut self, resource_id: ResourceId, @@ -517,6 +518,7 @@ impl ClientState { } #[tracing::instrument(level = "trace", skip_all, fields(%resource_id))] + #[expect(deprecated, reason = "Will be deleted together with deprecated API")] pub fn accept_answer( &mut self, answer: snownet::Answer, @@ -539,6 +541,10 @@ impl ClientState { /// /// In a nutshell, this tells us which gateway in which site to use for the given resource. #[tracing::instrument(level = "debug", skip_all, fields(%resource_id, %gateway_id))] + #[expect( + deprecated, + reason = "Will be refactored when deprecated control protocol is shipped" + )] pub fn on_routing_details( &mut self, resource_id: ResourceId, diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index bbc904c6d..ca8bccb97 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -35,6 +35,7 @@ impl GatewayTunnel { } /// Accept a connection request from a client. + #[expect(deprecated, reason = "Will be deleted together with deprecated API")] pub fn accept( &mut self, client_id: ClientId, @@ -249,6 +250,7 @@ impl GatewayState { } /// Accept a connection request from a client. + #[expect(deprecated, reason = "Will be deleted together with deprecated API")] pub fn accept( &mut self, client_id: ClientId, diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index c9e4a0475..2c4f2b618 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -709,6 +709,7 @@ impl TunnelTest { c.ipv6_routes = config.ipv6_routes; }); } + #[expect(deprecated, reason = "Will be deleted together with deprecated API")] ClientEvent::RequestConnection { gateway_id, offer,