feat(snownet): reduce connection setup latency (#3772)

Previously, we pretty much always lost the wireguard handshake packet,
causing us to wait for the rekey-timeout before we try again.

We can fix this by first checking that we actually have a socket that we
can send the encapsulated packet on. Additionally, we can directly force
a wireguard handshake as soon as we discover the first socket to the
remote.

This reduces the setup latency to ~3 seconds in my testing.

Resolves: #3779.
This commit is contained in:
Thomas Eizinger
2024-02-28 11:15:17 +11:00
committed by GitHub
parent beb5c3834d
commit ca0839d072

View File

@@ -193,6 +193,7 @@ where
///
/// To do that, we need to check all candidates of each allocation and compare their IP.
/// The same relay might be reachable over IPv4 and IPv6.
#[must_use]
fn same_relay_as_peer(&mut self, id: TId, candidate: &Candidate) -> Option<&mut Allocation> {
self.allocations
.iter_mut()
@@ -272,6 +273,9 @@ where
.get_established_mut(&connection)
.ok_or(Error::NotConnected)?;
// Must bail early if we don't have a socket yet to avoid running into WG timeouts.
let socket = conn.peer_socket.ok_or(Error::NotConnected)?;
let (header, payload) = self.buffer.as_mut().split_at_mut(4);
let packet_len = match conn.tunnel.encapsulate(packet.packet(), payload) {
@@ -285,7 +289,7 @@ where
let packet = &payload[..packet_len];
match conn.peer_socket.ok_or(Error::NotConnected)? {
match socket {
PeerSocket::Direct {
dest: remote,
source,
@@ -321,6 +325,7 @@ where
}
/// Returns a pending [`Event`] from the pool.
#[must_use]
pub fn poll_event(&mut self) -> Option<Event<TId>> {
let binding_events = self.bindings.iter_mut().flat_map(|(server, binding)| {
iter::from_fn(|| binding.poll_event().map(|e| (*server, e)))
@@ -411,6 +416,12 @@ where
conn.peer_socket = Some(remote_socket);
if is_first_connection {
tracing::info!(%id, "Starting wireguard handshake");
self.buffered_transmits.extend(
conn.force_handshake(&mut self.allocations, self.last_now),
);
return Some(Event::ConnectionEstablished(id));
}
}
@@ -434,6 +445,7 @@ where
///
/// This function only takes `&mut self` because it caches certain computations internally.
/// The returned timestamp will **not** change unless other state is modified.
#[must_use]
pub fn poll_timeout(&mut self) -> Option<Instant> {
let mut connection_timeout = None;
@@ -513,6 +525,7 @@ where
}
/// Returns buffered data that needs to be sent on the socket.
#[must_use]
pub fn poll_transmit(&mut self) -> Option<Transmit<'static>> {
for (_, conn) in self.connections.iter_established_mut() {
if let Some(transmit) = conn.poll_transmit(&mut self.allocations, self.last_now) {
@@ -535,6 +548,7 @@ where
self.buffered_transmits.pop_front()
}
#[must_use]
fn init_connection(
&mut self,
mut agent: IceAgent,
@@ -588,6 +602,7 @@ where
Ok(())
}
#[must_use]
fn bindings_try_handle(
&mut self,
from: SocketAddr,
@@ -609,6 +624,7 @@ where
}
/// Tries to handle the packet using one of our [`Allocation`]s.
#[must_use]
fn allocations_try_handle<'p>(
&mut self,
from: SocketAddr,
@@ -645,6 +661,7 @@ where
}
}
#[must_use]
fn agents_try_handle(
&mut self,
from: SocketAddr,
@@ -677,6 +694,7 @@ where
}))
}
#[must_use]
fn connections_try_handle<'b>(
&mut self,
from: SocketAddr,
@@ -762,6 +780,7 @@ where
/// Out of all configured STUN and TURN servers, the connection will only use the ones provided here.
/// The returned [`Offer`] must be passed to the remote via a signalling channel.
#[tracing::instrument(level = "info", skip_all, fields(%id))]
#[must_use]
pub fn new_connection(
&mut self,
id: TId,
@@ -812,11 +831,13 @@ where
debug_assert!(existing.is_none());
tracing::info!("Establishing new connection");
params
}
/// Accept an [`Answer`] from the remote for a connection previously created via [`Node::new_connection`].
#[tracing::instrument(level = "debug", skip_all, fields(%id, remote = %hex::encode(remote.as_bytes())))]
#[tracing::instrument(level = "info", skip_all, fields(%id))]
pub fn accept_answer(&mut self, id: TId, remote: PublicKey, answer: Answer) {
let Some(initial) = self.connections.initial.remove(&id) else {
tracing::debug!("No initial connection state, ignoring answer"); // This can happen if the connection setup timed out.
@@ -846,6 +867,8 @@ where
let existing = self.connections.established.insert(id, connection);
tracing::info!(remote = %hex::encode(remote.as_bytes()), "Signalling protocol completed");
debug_assert!(existing.is_none());
}
}
@@ -859,6 +882,7 @@ where
/// Out of all configured STUN and TURN servers, the connection will only use the ones provided here.
/// The returned [`Answer`] must be passed to the remote via a signalling channel.
#[tracing::instrument(level = "info", skip_all, fields(%id))]
#[must_use]
pub fn accept_connection(
&mut self,
id: TId,
@@ -916,6 +940,8 @@ where
debug_assert!(existing.is_none());
tracing::info!("Created new connection");
answer
}
}
@@ -1271,6 +1297,7 @@ impl Connection {
///
/// Whilst we establish connections, we may see traffic from a certain address, prior to the negotiation being fully complete.
/// We already want to accept that traffic and not throw it away.
#[must_use]
fn accepts(&self, addr: SocketAddr) -> bool {
let from_connected_remote = self.peer_socket.is_some_and(|r| match r {
PeerSocket::Direct { dest, .. } => dest == addr,
@@ -1304,6 +1331,7 @@ impl Connection {
}
}
#[must_use]
fn poll_timeout(&mut self) -> Option<Instant> {
let agent_timeout = self.agent.poll_timeout();
let next_wg_timer = Some(self.next_timer_update);
@@ -1311,6 +1339,7 @@ impl Connection {
earliest(agent_timeout, next_wg_timer)
}
#[must_use]
fn poll_transmit(
&mut self,
allocations: &mut HashMap<SocketAddr, Allocation>,
@@ -1392,6 +1421,7 @@ impl Connection {
Ok(None)
}
#[must_use]
fn encapsulate(
&self,
message: &[u8],
@@ -1412,4 +1442,26 @@ impl Connection {
}
}
}
#[must_use]
fn force_handshake(
&mut self,
allocations: &mut HashMap<SocketAddr, Allocation>,
now: Instant,
) -> Option<Transmit<'static>> {
/// [`boringtun`] requires us to pass buffers in where it can construct its packets.
///
/// When updating the timers, the largest packet that we may have to send is `148` bytes as per `HANDSHAKE_INIT_SZ` constant in [`boringtun`].
const MAX_SCRATCH_SPACE: usize = 148;
let mut buf = [0u8; MAX_SCRATCH_SPACE];
let TunnResult::WriteToNetwork(bytes) =
self.tunnel.format_handshake_initiation(&mut buf, true)
else {
return None;
};
self.encapsulate(bytes, allocations, now)
}
}