feat(snownet): introduce upsert_connection (#6937)

One of the key differences of the new control protocol designed in #6461
is that creating new connections is idempotent. We achieve this by
having the portal generate the ICE credentials and the preshared-key for
the WireGuard tunnel. As long as the ICE credentials don't change, we
don't need to make a new connection.

For `snownet`, this means we are deprecating the previous APIs for
making connections. The client-side APIs will have to stay around until
we merge the client-part of the new control protocol. The server-side
APIs will have to stay around until we remove backwards-compatibility
from the gateway.
This commit is contained in:
Thomas Eizinger
2024-10-08 12:37:42 +11:00
committed by GitHub
parent 4defb3b038
commit b7795dfa03
6 changed files with 167 additions and 19 deletions

View File

@@ -10,8 +10,10 @@ mod stats;
mod utils;
pub use allocation::RelaySocket;
#[allow(deprecated)] // Rust bug: `expect` doesn't seem to work on imports?
pub use node::{Answer, Offer};
pub use node::{
Answer, Client, ClientNode, Credentials, EncryptBuffer, EncryptedPacket, Error, Event, Node,
Offer, Server, ServerNode, Transmit, HANDSHAKE_TIMEOUT,
Client, ClientNode, Credentials, EncryptBuffer, EncryptedPacket, Error, Event, Node, Server,
ServerNode, Transmit, HANDSHAKE_TIMEOUT,
};
pub use stats::{ConnectionStats, NodeStats};

View File

@@ -20,7 +20,6 @@ use std::borrow::Cow;
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, BTreeSet};
use std::hash::Hash;
use std::marker::PhantomData;
use std::mem;
use std::ops::ControlFlow;
use std::time::{Duration, Instant};
@@ -45,8 +44,36 @@ pub type ServerNode<TId, RId> = Node<Server, TId, RId>;
/// Manages a set of wireguard connections for a client.
pub type ClientNode<TId, RId> = Node<Client, TId, RId>;
pub enum Server {}
pub enum Client {}
/// Type-level tag selecting the server role of a node.
#[non_exhaustive]
pub struct Server {}

/// Type-level tag selecting the client role of a node.
#[non_exhaustive]
pub struct Client {}

/// Role-specific behaviour shared by [`Server`] and [`Client`].
///
/// The trait is private; it only exists so generic code can construct the
/// role value and ask which role it is.
trait Mode {
    /// Constructs the role value.
    fn new() -> Self;

    /// Reports whether this role is the client side.
    fn is_client(&self) -> bool;
}

impl Mode for Client {
    fn new() -> Self {
        Client {}
    }

    fn is_client(&self) -> bool {
        true
    }
}

impl Mode for Server {
    fn new() -> Self {
        Server {}
    }

    fn is_client(&self) -> bool {
        false
    }
}
/// A node within a `snownet` network maintains connections to several other nodes.
///
@@ -96,7 +123,7 @@ pub struct Node<T, TId, RId> {
stats: NodeStats,
marker: PhantomData<T>,
mode: T,
rng: StdRng,
}
@@ -118,10 +145,12 @@ pub enum Error {
BadLocalAddress(#[from] str0m::error::IceError),
}
#[expect(private_bounds, reason = "We don't want `Mode` to be public API")]
impl<T, TId, RId> Node<T, TId, RId>
where
TId: Eq + Hash + Copy + Ord + fmt::Display,
RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display,
T: Mode,
{
pub fn new(seed: [u8; 32]) -> Self {
let mut rng = StdRng::from_seed(seed);
@@ -133,7 +162,7 @@ where
session_id: SessionId::new(*public_key),
private_key,
public_key: *public_key,
marker: Default::default(),
mode: T::new(),
index: IndexLfsr::default(),
rate_limiter: Arc::new(RateLimiter::new(public_key, HANDSHAKE_RATE_LIMIT)),
host_candidates: Default::default(),
@@ -189,6 +218,65 @@ where
self.connections.len()
}
/// Upserts a connection to the given remote.
///
/// If we already have a connection with the same ICE credentials, this does nothing.
/// Otherwise, the existing connection is discarded and a new one will be created.
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
pub fn upsert_connection(
&mut self,
cid: TId,
remote: PublicKey,
session_key: Secret<[u8; 32]>,
local_creds: Credentials,
remote_creds: Credentials,
now: Instant,
) {
let local_creds = local_creds.into();
let remote_creds = remote_creds.into();
if self.connections.initial.contains_key(&cid) {
debug_assert!(false, "The new `upsert_connection` API is incompatible with the previous `new_connection` API");
return;
}
if self
.connections
.get_established_mut(&cid)
.is_some_and(|c| c.agent.local_credentials() == &local_creds)
{
tracing::debug!("Already got a connection");
return;
}
let selected_relay = self.sample_relay();
let mut agent = new_agent();
agent.set_controlling(self.mode.is_client());
agent.set_local_credentials(local_creds);
agent.set_remote_credentials(remote_creds);
self.seed_agent_with_local_candidates(cid, selected_relay, &mut agent);
let connection = self.init_connection(
cid,
agent,
remote,
*session_key.expose_secret(),
selected_relay,
now,
now,
);
let existing = self.connections.established.insert(cid, connection);
if existing.is_some() {
tracing::info!("Replaced existing connection");
} else {
tracing::info!("Created new connection");
}
}
/// Returns the public key stored on this node.
pub fn public_key(&self) -> PublicKey {
self.public_key
}
@@ -853,6 +941,8 @@ where
/// The returned [`Offer`] must be passed to the remote via a signalling channel.
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
#[must_use]
#[deprecated]
#[expect(deprecated)]
pub fn new_connection(&mut self, cid: TId, intent_sent_at: Instant, now: Instant) -> Offer {
if self.connections.initial.remove(&cid).is_some() {
tracing::info!("Replacing existing initial connection");
@@ -902,6 +992,8 @@ where
/// Accept an [`Answer`] from the remote for a connection previously created via [`Node::new_connection`].
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
#[deprecated]
#[expect(deprecated)]
pub fn accept_answer(&mut self, cid: TId, remote: PublicKey, answer: Answer, now: Instant) {
let Some(initial) = self.connections.initial.remove(&cid) else {
tracing::debug!("No initial connection state, ignoring answer"); // This can happen if the connection setup timed out.
@@ -948,6 +1040,8 @@ where
/// The returned [`Answer`] must be passed to the remote via a signalling channel.
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
#[must_use]
#[deprecated]
#[expect(deprecated)]
pub fn accept_connection(
&mut self,
cid: TId,
@@ -1334,12 +1428,14 @@ fn remove_local_candidate<TId>(
}
}
/// Connection offer returned by the deprecated `Node::new_connection`.
///
/// It must be passed to the remote via a signalling channel.
#[deprecated]
pub struct Offer {
/// The Wireguard session key for a connection.
pub session_key: Secret<[u8; 32]>,
/// Credentials sent along with the session key.
// NOTE(review): presumably the offering side's ICE credentials — confirm against `new_connection`.
pub credentials: Credentials,
}
/// Reply to an `Offer` in the deprecated connection flow.
///
/// Returned by `accept_connection`, passed back over the signalling channel and
/// handed to `Node::accept_answer` on the offering side.
#[deprecated]
pub struct Answer {
/// Credentials sent back to the offering side.
// NOTE(review): presumably the answering side's ICE credentials — confirm against `accept_connection`.
pub credentials: Credentials,
}
@@ -1351,6 +1447,16 @@ pub struct Credentials {
pub password: String,
}
#[doc(hidden)] // Not public API.
impl From<Credentials> for str0m::IceCreds {
fn from(value: Credentials) -> Self {
str0m::IceCreds {
ufrag: value.username,
pass: value.password,
}
}
}
#[derive(Debug, PartialEq, Clone)]
pub enum Event<TId> {
/// We created a new candidate for this connection and ask to signal it to the remote party.

View File

@@ -1,4 +1,5 @@
use snownet::{Answer, ClientNode, Event, ServerNode};
use secrecy::Secret;
use snownet::{ClientNode, Credentials, Event, ServerNode};
use std::{
iter,
net::{IpAddr, Ipv4Addr, SocketAddr},
@@ -7,6 +8,7 @@ use std::{
use str0m::{net::Protocol, Candidate};
#[test]
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
fn connection_times_out_after_20_seconds() {
let (mut alice, _) = alice_and_bob();
@@ -24,12 +26,9 @@ fn connection_without_candidates_times_out_after_10_seconds() {
let start = Instant::now();
let (mut alice, mut bob) = alice_and_bob();
let answer = send_offer(&mut alice, &mut bob, start);
handshake(&mut alice, &mut bob, start);
let accepted_at = start + Duration::from_secs(1);
alice.accept_answer(1, bob.public_key(), answer, accepted_at);
alice.handle_timeout(accepted_at + Duration::from_secs(10));
alice.handle_timeout(start + Duration::from_secs(10));
assert_eq!(alice.poll_event().unwrap(), Event::ConnectionFailed(1));
}
@@ -40,14 +39,12 @@ fn connection_with_candidates_does_not_time_out_after_10_seconds() {
let start = Instant::now();
let (mut alice, mut bob) = alice_and_bob();
let answer = send_offer(&mut alice, &mut bob, start);
handshake(&mut alice, &mut bob, start);
let accepted_at = start + Duration::from_secs(1);
alice.accept_answer(1, bob.public_key(), answer, accepted_at);
alice.add_local_host_candidate(s("10.0.0.2:4444")).unwrap();
alice.add_remote_candidate(1, host("10.0.0.1:4444"), accepted_at);
alice.add_remote_candidate(1, host("10.0.0.1:4444"), start);
alice.handle_timeout(accepted_at + Duration::from_secs(10));
alice.handle_timeout(start + Duration::from_secs(10));
let any_failed =
iter::from_fn(|| alice.poll_event()).any(|e| matches!(e, Event::ConnectionFailed(_)));
@@ -56,6 +53,7 @@ fn connection_with_candidates_does_not_time_out_after_10_seconds() {
}
#[test]
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
fn answer_after_stale_connection_does_not_panic() {
let start = Instant::now();
@@ -69,6 +67,7 @@ fn answer_after_stale_connection_does_not_panic() {
}
#[test]
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
fn only_generate_candidate_event_after_answer() {
let local_candidate = SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 10000);
@@ -105,16 +104,48 @@ fn alice_and_bob() -> (ClientNode<u64, u64>, ServerNode<u64, u64>) {
(alice, bob)
}
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
fn send_offer(
alice: &mut ClientNode<u64, u64>,
bob: &mut ServerNode<u64, u64>,
now: Instant,
) -> Answer {
) -> snownet::Answer {
let offer = alice.new_connection(1, Instant::now(), now);
bob.accept_connection(1, offer, alice.public_key(), now)
}
/// Sets up connection `1` between `alice` and `bob` via `upsert_connection`.
///
/// Both sides use an all-zero session key. Alice's local credentials ("foo") are
/// Bob's remote credentials and vice versa ("bar").
fn handshake(alice: &mut ClientNode<u64, u64>, bob: &mut ServerNode<u64, u64>, now: Instant) {
    // Username and password are both set to the given label.
    let creds = |label: &str| Credentials {
        username: label.to_owned(),
        password: label.to_owned(),
    };
    let session_key = || Secret::new([0u8; 32]);

    let bob_key = bob.public_key();
    alice.upsert_connection(1, bob_key, session_key(), creds("foo"), creds("bar"), now);

    let alice_key = alice.public_key();
    bob.upsert_connection(1, alice_key, session_key(), creds("bar"), creds("foo"), now);
}
fn host(socket: &str) -> String {
Candidate::host(s(socket), Protocol::Udp)
.unwrap()

View File

@@ -179,6 +179,7 @@ impl ClientTunnel {
.on_routing_details(resource_id, gateway_id, site_id, Instant::now())
}
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
pub fn received_offer_response(
&mut self,
resource_id: ResourceId,
@@ -517,6 +518,7 @@ impl ClientState {
}
#[tracing::instrument(level = "trace", skip_all, fields(%resource_id))]
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
pub fn accept_answer(
&mut self,
answer: snownet::Answer,
@@ -539,6 +541,10 @@ impl ClientState {
///
/// In a nutshell, this tells us which gateway in which site to use for the given resource.
#[tracing::instrument(level = "debug", skip_all, fields(%resource_id, %gateway_id))]
#[expect(
deprecated,
reason = "Will be refactored when deprecated control protocol is shipped"
)]
pub fn on_routing_details(
&mut self,
resource_id: ResourceId,

View File

@@ -35,6 +35,7 @@ impl GatewayTunnel {
}
/// Accept a connection request from a client.
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
pub fn accept(
&mut self,
client_id: ClientId,
@@ -249,6 +250,7 @@ impl GatewayState {
}
/// Accept a connection request from a client.
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
pub fn accept(
&mut self,
client_id: ClientId,

View File

@@ -709,6 +709,7 @@ impl TunnelTest {
c.ipv6_routes = config.ipv6_routes;
});
}
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
ClientEvent::RequestConnection {
gateway_id,
offer,