mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
feat(snownet): introduce upsert_connection (#6937)
One of the key differences of the new control protocol designed in #6461 is that creating new connections is idempotent. We achieve this by having the portal generate the ICE credentials and the preshared-key for the WireGuard tunnel. As long as the ICE credentials don't change, we don't need to make a new connection. For `snownet`, this means we are deprecating the previous APIs for making connections. The client-side APIs will have to stay around until we merge the client-part of the new control protocol. The server-side APIs will have to stay around until we remove backwards-compatibility from the gateway.
This commit is contained in:
@@ -10,8 +10,10 @@ mod stats;
|
||||
mod utils;
|
||||
|
||||
pub use allocation::RelaySocket;
|
||||
#[allow(deprecated)] // Rust bug: `expect` doesn't seem to work on imports?
|
||||
pub use node::{Answer, Offer};
|
||||
pub use node::{
|
||||
Answer, Client, ClientNode, Credentials, EncryptBuffer, EncryptedPacket, Error, Event, Node,
|
||||
Offer, Server, ServerNode, Transmit, HANDSHAKE_TIMEOUT,
|
||||
Client, ClientNode, Credentials, EncryptBuffer, EncryptedPacket, Error, Event, Node, Server,
|
||||
ServerNode, Transmit, HANDSHAKE_TIMEOUT,
|
||||
};
|
||||
pub use stats::{ConnectionStats, NodeStats};
|
||||
|
||||
@@ -20,7 +20,6 @@ use std::borrow::Cow;
|
||||
use std::collections::btree_map::Entry;
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::hash::Hash;
|
||||
use std::marker::PhantomData;
|
||||
use std::mem;
|
||||
use std::ops::ControlFlow;
|
||||
use std::time::{Duration, Instant};
|
||||
@@ -45,8 +44,36 @@ pub type ServerNode<TId, RId> = Node<Server, TId, RId>;
|
||||
/// Manages a set of wireguard connections for a client.
|
||||
pub type ClientNode<TId, RId> = Node<Client, TId, RId>;
|
||||
|
||||
pub enum Server {}
|
||||
pub enum Client {}
|
||||
#[non_exhaustive]
|
||||
pub struct Server {}
|
||||
|
||||
#[non_exhaustive]
|
||||
pub struct Client {}
|
||||
|
||||
trait Mode {
|
||||
fn new() -> Self;
|
||||
fn is_client(&self) -> bool;
|
||||
}
|
||||
|
||||
impl Mode for Server {
|
||||
fn is_client(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl Mode for Client {
|
||||
fn is_client(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
/// A node within a `snownet` network maintains connections to several other nodes.
|
||||
///
|
||||
@@ -96,7 +123,7 @@ pub struct Node<T, TId, RId> {
|
||||
|
||||
stats: NodeStats,
|
||||
|
||||
marker: PhantomData<T>,
|
||||
mode: T,
|
||||
rng: StdRng,
|
||||
}
|
||||
|
||||
@@ -118,10 +145,12 @@ pub enum Error {
|
||||
BadLocalAddress(#[from] str0m::error::IceError),
|
||||
}
|
||||
|
||||
#[expect(private_bounds, reason = "We don't want `Mode` to be public API")]
|
||||
impl<T, TId, RId> Node<T, TId, RId>
|
||||
where
|
||||
TId: Eq + Hash + Copy + Ord + fmt::Display,
|
||||
RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display,
|
||||
T: Mode,
|
||||
{
|
||||
pub fn new(seed: [u8; 32]) -> Self {
|
||||
let mut rng = StdRng::from_seed(seed);
|
||||
@@ -133,7 +162,7 @@ where
|
||||
session_id: SessionId::new(*public_key),
|
||||
private_key,
|
||||
public_key: *public_key,
|
||||
marker: Default::default(),
|
||||
mode: T::new(),
|
||||
index: IndexLfsr::default(),
|
||||
rate_limiter: Arc::new(RateLimiter::new(public_key, HANDSHAKE_RATE_LIMIT)),
|
||||
host_candidates: Default::default(),
|
||||
@@ -189,6 +218,65 @@ where
|
||||
self.connections.len()
|
||||
}
|
||||
|
||||
/// Upserts a connection to the given remote.
|
||||
///
|
||||
/// If we already have a connection with the same ICE credentials, this does nothing.
|
||||
/// Otherwise, the existing connection is discarded and a new one will be created.
|
||||
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
|
||||
pub fn upsert_connection(
|
||||
&mut self,
|
||||
cid: TId,
|
||||
remote: PublicKey,
|
||||
session_key: Secret<[u8; 32]>,
|
||||
local_creds: Credentials,
|
||||
remote_creds: Credentials,
|
||||
now: Instant,
|
||||
) {
|
||||
let local_creds = local_creds.into();
|
||||
let remote_creds = remote_creds.into();
|
||||
|
||||
if self.connections.initial.contains_key(&cid) {
|
||||
debug_assert!(false, "The new `upsert_connection` API is incompatible with the previous `new_connection` API");
|
||||
return;
|
||||
}
|
||||
|
||||
if self
|
||||
.connections
|
||||
.get_established_mut(&cid)
|
||||
.is_some_and(|c| c.agent.local_credentials() == &local_creds)
|
||||
{
|
||||
tracing::debug!("Already got a connection");
|
||||
return;
|
||||
}
|
||||
|
||||
let selected_relay = self.sample_relay();
|
||||
|
||||
let mut agent = new_agent();
|
||||
agent.set_controlling(self.mode.is_client());
|
||||
agent.set_local_credentials(local_creds);
|
||||
agent.set_remote_credentials(remote_creds);
|
||||
|
||||
self.seed_agent_with_local_candidates(cid, selected_relay, &mut agent);
|
||||
|
||||
let connection = self.init_connection(
|
||||
cid,
|
||||
agent,
|
||||
remote,
|
||||
*session_key.expose_secret(),
|
||||
selected_relay,
|
||||
now,
|
||||
now,
|
||||
);
|
||||
|
||||
let existing = self.connections.established.insert(cid, connection);
|
||||
|
||||
if existing.is_some() {
|
||||
tracing::info!("Replaced existing connection");
|
||||
} else {
|
||||
tracing::info!("Created new connection");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn public_key(&self) -> PublicKey {
|
||||
self.public_key
|
||||
}
|
||||
@@ -853,6 +941,8 @@ where
|
||||
/// The returned [`Offer`] must be passed to the remote via a signalling channel.
|
||||
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
|
||||
#[must_use]
|
||||
#[deprecated]
|
||||
#[expect(deprecated)]
|
||||
pub fn new_connection(&mut self, cid: TId, intent_sent_at: Instant, now: Instant) -> Offer {
|
||||
if self.connections.initial.remove(&cid).is_some() {
|
||||
tracing::info!("Replacing existing initial connection");
|
||||
@@ -902,6 +992,8 @@ where
|
||||
|
||||
/// Accept an [`Answer`] from the remote for a connection previously created via [`Node::new_connection`].
|
||||
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
|
||||
#[deprecated]
|
||||
#[expect(deprecated)]
|
||||
pub fn accept_answer(&mut self, cid: TId, remote: PublicKey, answer: Answer, now: Instant) {
|
||||
let Some(initial) = self.connections.initial.remove(&cid) else {
|
||||
tracing::debug!("No initial connection state, ignoring answer"); // This can happen if the connection setup timed out.
|
||||
@@ -948,6 +1040,8 @@ where
|
||||
/// The returned [`Answer`] must be passed to the remote via a signalling channel.
|
||||
#[tracing::instrument(level = "info", skip_all, fields(%cid))]
|
||||
#[must_use]
|
||||
#[deprecated]
|
||||
#[expect(deprecated)]
|
||||
pub fn accept_connection(
|
||||
&mut self,
|
||||
cid: TId,
|
||||
@@ -1334,12 +1428,14 @@ fn remove_local_candidate<TId>(
|
||||
}
|
||||
}
|
||||
|
||||
#[deprecated]
|
||||
pub struct Offer {
|
||||
/// The Wireguard session key for a connection.
|
||||
pub session_key: Secret<[u8; 32]>,
|
||||
pub credentials: Credentials,
|
||||
}
|
||||
|
||||
#[deprecated]
|
||||
pub struct Answer {
|
||||
pub credentials: Credentials,
|
||||
}
|
||||
@@ -1351,6 +1447,16 @@ pub struct Credentials {
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[doc(hidden)] // Not public API.
|
||||
impl From<Credentials> for str0m::IceCreds {
|
||||
fn from(value: Credentials) -> Self {
|
||||
str0m::IceCreds {
|
||||
ufrag: value.username,
|
||||
pass: value.password,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub enum Event<TId> {
|
||||
/// We created a new candidate for this connection and ask to signal it to the remote party.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use snownet::{Answer, ClientNode, Event, ServerNode};
|
||||
use secrecy::Secret;
|
||||
use snownet::{ClientNode, Credentials, Event, ServerNode};
|
||||
use std::{
|
||||
iter,
|
||||
net::{IpAddr, Ipv4Addr, SocketAddr},
|
||||
@@ -7,6 +8,7 @@ use std::{
|
||||
use str0m::{net::Protocol, Candidate};
|
||||
|
||||
#[test]
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
fn connection_times_out_after_20_seconds() {
|
||||
let (mut alice, _) = alice_and_bob();
|
||||
|
||||
@@ -24,12 +26,9 @@ fn connection_without_candidates_times_out_after_10_seconds() {
|
||||
let start = Instant::now();
|
||||
|
||||
let (mut alice, mut bob) = alice_and_bob();
|
||||
let answer = send_offer(&mut alice, &mut bob, start);
|
||||
handshake(&mut alice, &mut bob, start);
|
||||
|
||||
let accepted_at = start + Duration::from_secs(1);
|
||||
alice.accept_answer(1, bob.public_key(), answer, accepted_at);
|
||||
|
||||
alice.handle_timeout(accepted_at + Duration::from_secs(10));
|
||||
alice.handle_timeout(start + Duration::from_secs(10));
|
||||
|
||||
assert_eq!(alice.poll_event().unwrap(), Event::ConnectionFailed(1));
|
||||
}
|
||||
@@ -40,14 +39,12 @@ fn connection_with_candidates_does_not_time_out_after_10_seconds() {
|
||||
let start = Instant::now();
|
||||
|
||||
let (mut alice, mut bob) = alice_and_bob();
|
||||
let answer = send_offer(&mut alice, &mut bob, start);
|
||||
handshake(&mut alice, &mut bob, start);
|
||||
|
||||
let accepted_at = start + Duration::from_secs(1);
|
||||
alice.accept_answer(1, bob.public_key(), answer, accepted_at);
|
||||
alice.add_local_host_candidate(s("10.0.0.2:4444")).unwrap();
|
||||
alice.add_remote_candidate(1, host("10.0.0.1:4444"), accepted_at);
|
||||
alice.add_remote_candidate(1, host("10.0.0.1:4444"), start);
|
||||
|
||||
alice.handle_timeout(accepted_at + Duration::from_secs(10));
|
||||
alice.handle_timeout(start + Duration::from_secs(10));
|
||||
|
||||
let any_failed =
|
||||
iter::from_fn(|| alice.poll_event()).any(|e| matches!(e, Event::ConnectionFailed(_)));
|
||||
@@ -56,6 +53,7 @@ fn connection_with_candidates_does_not_time_out_after_10_seconds() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
fn answer_after_stale_connection_does_not_panic() {
|
||||
let start = Instant::now();
|
||||
|
||||
@@ -69,6 +67,7 @@ fn answer_after_stale_connection_does_not_panic() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
fn only_generate_candidate_event_after_answer() {
|
||||
let local_candidate = SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 10000);
|
||||
|
||||
@@ -105,16 +104,48 @@ fn alice_and_bob() -> (ClientNode<u64, u64>, ServerNode<u64, u64>) {
|
||||
(alice, bob)
|
||||
}
|
||||
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
fn send_offer(
|
||||
alice: &mut ClientNode<u64, u64>,
|
||||
bob: &mut ServerNode<u64, u64>,
|
||||
now: Instant,
|
||||
) -> Answer {
|
||||
) -> snownet::Answer {
|
||||
let offer = alice.new_connection(1, Instant::now(), now);
|
||||
|
||||
bob.accept_connection(1, offer, alice.public_key(), now)
|
||||
}
|
||||
|
||||
fn handshake(alice: &mut ClientNode<u64, u64>, bob: &mut ServerNode<u64, u64>, now: Instant) {
|
||||
alice.upsert_connection(
|
||||
1,
|
||||
bob.public_key(),
|
||||
Secret::new([0u8; 32]),
|
||||
Credentials {
|
||||
username: "foo".to_owned(),
|
||||
password: "foo".to_owned(),
|
||||
},
|
||||
Credentials {
|
||||
username: "bar".to_owned(),
|
||||
password: "bar".to_owned(),
|
||||
},
|
||||
now,
|
||||
);
|
||||
bob.upsert_connection(
|
||||
1,
|
||||
alice.public_key(),
|
||||
Secret::new([0u8; 32]),
|
||||
Credentials {
|
||||
username: "bar".to_owned(),
|
||||
password: "bar".to_owned(),
|
||||
},
|
||||
Credentials {
|
||||
username: "foo".to_owned(),
|
||||
password: "foo".to_owned(),
|
||||
},
|
||||
now,
|
||||
);
|
||||
}
|
||||
|
||||
fn host(socket: &str) -> String {
|
||||
Candidate::host(s(socket), Protocol::Udp)
|
||||
.unwrap()
|
||||
|
||||
@@ -179,6 +179,7 @@ impl ClientTunnel {
|
||||
.on_routing_details(resource_id, gateway_id, site_id, Instant::now())
|
||||
}
|
||||
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
pub fn received_offer_response(
|
||||
&mut self,
|
||||
resource_id: ResourceId,
|
||||
@@ -517,6 +518,7 @@ impl ClientState {
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip_all, fields(%resource_id))]
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
pub fn accept_answer(
|
||||
&mut self,
|
||||
answer: snownet::Answer,
|
||||
@@ -539,6 +541,10 @@ impl ClientState {
|
||||
///
|
||||
/// In a nutshell, this tells us which gateway in which site to use for the given resource.
|
||||
#[tracing::instrument(level = "debug", skip_all, fields(%resource_id, %gateway_id))]
|
||||
#[expect(
|
||||
deprecated,
|
||||
reason = "Will be refactored when deprecated control protocol is shipped"
|
||||
)]
|
||||
pub fn on_routing_details(
|
||||
&mut self,
|
||||
resource_id: ResourceId,
|
||||
|
||||
@@ -35,6 +35,7 @@ impl GatewayTunnel {
|
||||
}
|
||||
|
||||
/// Accept a connection request from a client.
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
pub fn accept(
|
||||
&mut self,
|
||||
client_id: ClientId,
|
||||
@@ -249,6 +250,7 @@ impl GatewayState {
|
||||
}
|
||||
|
||||
/// Accept a connection request from a client.
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
pub fn accept(
|
||||
&mut self,
|
||||
client_id: ClientId,
|
||||
|
||||
@@ -709,6 +709,7 @@ impl TunnelTest {
|
||||
c.ipv6_routes = config.ipv6_routes;
|
||||
});
|
||||
}
|
||||
#[expect(deprecated, reason = "Will be deleted together with deprecated API")]
|
||||
ClientEvent::RequestConnection {
|
||||
gateway_id,
|
||||
offer,
|
||||
|
||||
Reference in New Issue
Block a user