feat(snownet): retry TURN allocations using exponential backoffs (#3530)

Similar to https://github.com/firezone/firezone/pull/3529.
This commit is contained in:
Thomas Eizinger
2024-02-06 18:12:23 +11:00
committed by GitHub
parent 6fcfc5497d
commit 75732ca56a
6 changed files with 229 additions and 92 deletions

View File

@@ -1,4 +1,9 @@
use crate::node::Transmit;
use crate::{
backoff::{self, ExponentialBackoff},
node::Transmit,
utils::earliest,
};
use ::backoff::backoff::Backoff;
use bytecodec::{DecodeExt as _, EncodeExt as _};
use rand::random;
use std::{
@@ -45,12 +50,13 @@ pub struct Allocation {
buffered_transmits: VecDeque<Transmit<'static>>,
new_candidates: VecDeque<Candidate>,
sent_requests: HashMap<TransactionId, (Message<Attribute>, Instant)>,
backoff: ExponentialBackoff,
sent_requests: HashMap<TransactionId, (Message<Attribute>, Instant, Duration)>,
channel_bindings: ChannelBindings,
buffered_channel_bindings: VecDeque<Message<Attribute>>,
buffered_channel_bindings: BufferedChannelBindings,
last_now: Option<Instant>,
last_now: Instant,
username: Username,
password: String,
@@ -59,8 +65,14 @@ pub struct Allocation {
}
impl Allocation {
pub fn new(server: SocketAddr, username: Username, password: String, realm: Realm) -> Self {
Self {
pub fn new(
server: SocketAddr,
username: Username,
password: String,
realm: Realm,
now: Instant,
) -> Self {
let mut allocation = Self {
server,
last_srflx_candidate: Default::default(),
ip4_allocation: Default::default(),
@@ -74,9 +86,16 @@ impl Allocation {
nonce: Default::default(),
allocation_lifetime: Default::default(),
channel_bindings: Default::default(),
last_now: Default::default(),
last_now: now,
buffered_channel_bindings: Default::default(),
}
backoff: backoff::new(now, REQUEST_TIMEOUT),
};
tracing::debug!(%server, "Requesting new allocation");
allocation.authenticate_and_queue(make_allocate_request());
allocation
}
pub fn current_candidates(&self) -> impl Iterator<Item = Candidate> {
@@ -96,9 +115,7 @@ impl Allocation {
packet: &[u8],
now: Instant,
) -> bool {
if Some(now) > self.last_now {
self.last_now = Some(now);
}
self.update_now(now);
if from != self.server {
return false;
@@ -108,12 +125,14 @@ impl Allocation {
return false;
};
let Some((original_request, sent_at)) =
let Some((original_request, sent_at, _)) =
self.sent_requests.remove(&message.transaction_id())
else {
return false;
};
self.backoff.reset();
let rtt = now.duration_since(sent_at);
tracing::debug!(id = ?original_request.transaction_id(), method = %original_request.method(), ?rtt);
@@ -137,7 +156,7 @@ impl Allocation {
"Request failed, re-authenticating"
);
self.authenticate_and_queue(original_request, now);
self.authenticate_and_queue(original_request);
return true;
}
@@ -225,7 +244,7 @@ impl Allocation {
);
while let Some(buffered) = self.buffered_channel_bindings.pop_front() {
self.authenticate_and_queue(buffered, now);
self.authenticate_and_queue(buffered);
}
}
REFRESH => {
@@ -295,37 +314,29 @@ impl Allocation {
}
pub fn handle_timeout(&mut self, now: Instant) {
if Some(now) > self.last_now {
self.last_now = Some(now);
}
if !self.has_allocation() && !self.allocate_in_flight() {
tracing::debug!(server = %self.server, "Request new allocation");
self.authenticate_and_queue(make_allocate_request(), now);
}
self.update_now(now);
while let Some(timed_out_request) =
self.sent_requests.iter().find_map(|(id, (_, sent_at))| {
(now.duration_since(*sent_at) >= REQUEST_TIMEOUT).then_some(*id)
})
self.sent_requests
.iter()
.find_map(|(id, (_, sent_at, backoff))| {
(now.duration_since(*sent_at) >= *backoff).then_some(*id)
})
{
let (request, _) = self
let (request, _, _) = self
.sent_requests
.remove(&timed_out_request)
.expect("ID is from list");
tracing::debug!(id = ?request.transaction_id(), method = %request.method(), "Request timed out, re-sending");
self.authenticate_and_queue(request, now);
self.authenticate_and_queue(request);
}
if let Some((received_at, lifetime)) = self.allocation_lifetime {
let refresh_after = lifetime / 2;
if now > received_at + refresh_after {
tracing::debug!("Allocation is at 50% of its lifetime, refreshing");
self.authenticate_and_queue(make_refresh_request(), now);
if let Some(refresh_at) = self.refresh_allocation_at() {
if now > refresh_at {
tracing::debug!("Allocation is due for a refresh");
self.authenticate_and_queue(make_refresh_request());
}
}
@@ -336,7 +347,7 @@ impl Allocation {
.collect::<Vec<_>>(); // Need to allocate here to satisfy borrow-checker. Number of channel refresh messages should be small so this shouldn't be a big impact.
for message in refresh_messages {
self.authenticate_and_queue(message, now);
self.authenticate_and_queue(message);
}
// TODO: Clean up unused channels
@@ -350,11 +361,19 @@ impl Allocation {
self.buffered_transmits.pop_front()
}
// pub fn poll_timeout(&self) -> Option<Instant> {
// None // TODO: Implement this.
// }
pub fn poll_timeout(&self) -> Option<Instant> {
let mut earliest_timeout = self.refresh_allocation_at();
for (_, (_, sent_at, backoff)) in self.sent_requests.iter() {
earliest_timeout = earliest(earliest_timeout, Some(*sent_at + *backoff));
}
earliest_timeout
}
pub fn bind_channel(&mut self, peer: SocketAddr, now: Instant) {
self.update_now(now);
if self.channel_bindings.channel_to_peer(peer, now).is_some() {
tracing::debug!(relay = %self.server, %peer, "Already got a channel");
return;
@@ -374,7 +393,7 @@ impl Allocation {
return;
}
self.authenticate_and_queue(msg, now);
self.authenticate_and_queue(msg);
}
pub fn encode_to_slice(
@@ -403,18 +422,20 @@ impl Allocation {
Some(channel_data)
}
fn refresh_allocation_at(&self) -> Option<Instant> {
let (received_at, lifetime) = self.allocation_lifetime?;
let refresh_after = lifetime / 2;
Some(received_at + refresh_after)
}
fn has_allocation(&self) -> bool {
self.ip4_allocation.is_some() || self.ip6_allocation.is_some()
}
fn allocate_in_flight(&self) -> bool {
self.sent_requests
.values()
.any(|(r, _)| r.method() == ALLOCATE)
}
fn channel_binding_in_flight(&self, channel: u16) -> bool {
self.sent_requests.values().any(|(r, _)| {
self.sent_requests.values().any(|(r, _, _)| {
r.method() == BINDING
&& r.get_attribute::<ChannelNumber>()
.is_some_and(|n| n.value() == channel)
@@ -455,18 +476,33 @@ impl Allocation {
message
}
fn authenticate_and_queue(&mut self, message: Message<Attribute>, now: Instant) {
fn authenticate_and_queue(&mut self, message: Message<Attribute>) {
let Some(backoff) = self.backoff.next_backoff() else {
tracing::warn!(
"Unable to queue {} because we've exceeded our backoffs",
message.method()
);
return;
};
let authenticated_message = self.authenticate(message);
let id = authenticated_message.transaction_id();
self.sent_requests
.insert(id, (authenticated_message.clone(), now));
.insert(id, (authenticated_message.clone(), self.last_now, backoff));
self.buffered_transmits.push_back(Transmit {
src: None,
dst: self.server,
payload: encode(authenticated_message).into(),
});
}
fn update_now(&mut self, now: Instant) {
if now > self.last_now {
self.last_now = now;
self.backoff.clock.now = now;
}
}
}
fn update_candidate(
@@ -769,6 +805,35 @@ impl Channel {
}
}
#[derive(Debug, Default)]
struct BufferedChannelBindings {
inner: VecDeque<Message<Attribute>>,
}
impl BufferedChannelBindings {
/// Adds a new `CHANNEL-BIND` message to this buffer.
///
/// The buffer has a fixed size of 10 to avoid unbounded memory growth.
/// All prior messages are cleared once we outgrow the buffer.
/// Very likely, we buffer `CHANNEL-BIND` messages only for a brief period of time.
/// However, it might also happen that we can only re-connect to a TURN server after an extended period of downtime.
/// Chances are that we don't need any of the old channels any more, and that the new ones are much more relevant.
fn push_back(&mut self, msg: Message<Attribute>) {
debug_assert_eq!(msg.method(), CHANNEL_BIND);
if self.inner.len() == 10 {
tracing::debug!("Clearing buffered channel-data messages");
self.inner.clear()
}
self.inner.push_back(msg);
}
fn pop_front(&mut self) -> Option<Message<Attribute>> {
self.inner.pop_front()
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -954,15 +1019,19 @@ mod tests {
Username::new("foobar".to_owned()).unwrap(),
"baz".to_owned(),
Realm::new("firezone".to_owned()).unwrap(),
Instant::now(),
);
let allocate = next_stun_message(&mut allocation).unwrap();
assert_eq!(allocate.method(), ALLOCATE);
allocation.bind_channel(PEER1, Instant::now());
assert!(
next_stun_message(&mut allocation).is_none(),
"no messages to be sent if we don't have an allocation"
);
make_allocation(&mut allocation, PEER1);
make_allocation(&mut allocation, allocate.transaction_id(), PEER1);
let message = next_stun_message(&mut allocation).unwrap();
assert_eq!(message.method(), CHANNEL_BIND);
@@ -975,9 +1044,11 @@ mod tests {
Username::new("foobar".to_owned()).unwrap(),
"baz".to_owned(),
Realm::new("firezone".to_owned()).unwrap(),
Instant::now(),
);
make_allocation(&mut allocation, PEER1);
let allocate = next_stun_message(&mut allocation).unwrap();
make_allocation(&mut allocation, allocate.transaction_id(), PEER1);
allocation.bind_channel(PEER2, Instant::now());
let message = allocation.encode_to_vec(PEER2, b"foobar", Instant::now());
@@ -992,9 +1063,11 @@ mod tests {
Username::new("foobar".to_owned()).unwrap(),
"baz".to_owned(),
Realm::new("firezone".to_owned()).unwrap(),
Instant::now(),
);
make_allocation(&mut allocation, PEER1);
let allocate = next_stun_message(&mut allocation).unwrap();
make_allocation(&mut allocation, allocate.transaction_id(), PEER1);
allocation.bind_channel(PEER2, Instant::now());
let channel_bind_msg = next_stun_message(&mut allocation).unwrap();
@@ -1023,9 +1096,11 @@ mod tests {
Username::new("foobar".to_owned()).unwrap(),
"baz".to_owned(),
Realm::new("firezone".to_owned()).unwrap(),
Instant::now(),
);
make_allocation(&mut allocation, PEER1);
let allocate = next_stun_message(&mut allocation).unwrap();
make_allocation(&mut allocation, allocate.transaction_id(), PEER1);
allocation.bind_channel(PEER2, Instant::now());
let channel_bind_msg = next_stun_message(&mut allocation).unwrap();
@@ -1042,6 +1117,57 @@ mod tests {
assert!(next_msg.is_none())
}
#[test]
fn retries_requests_using_backoff_and_gives_up_eventually() {
let start = Instant::now();
let mut allocation = Allocation::new(
RELAY,
Username::new("foobar".to_owned()).unwrap(),
"baz".to_owned(),
Realm::new("firezone".to_owned()).unwrap(),
start,
);
let mut expected_backoffs = VecDeque::from(backoff::steps(start));
loop {
let Some(timeout) = allocation.poll_timeout() else {
break;
};
assert_eq!(expected_backoffs.pop_front().unwrap(), timeout);
assert!(allocation.poll_transmit().is_some());
assert!(allocation.poll_transmit().is_none());
allocation.handle_timeout(timeout);
}
assert!(expected_backoffs.is_empty())
}
#[test]
fn discards_old_channel_bindings_once_we_outgrow_buffer() {
let mut buffered_channel_bindings = BufferedChannelBindings::default();
for c in 0..11 {
buffered_channel_bindings.push_back(make_channel_bind_request(
PEER1,
ChannelBindings::FIRST_CHANNEL + c,
));
}
let msg = buffered_channel_bindings.pop_front().unwrap();
assert!(
buffered_channel_bindings.pop_front().is_none(),
"no more messages"
);
assert_eq!(
msg.get_attribute::<ChannelNumber>().unwrap().value(),
ChannelBindings::FIRST_CHANNEL + 10
);
}
fn ch(peer: SocketAddr, now: Instant) -> Channel {
Channel {
peer,
@@ -1057,13 +1183,11 @@ mod tests {
Some(decode(&transmit.payload).unwrap().unwrap())
}
fn make_allocation(allocation: &mut Allocation, local: SocketAddr) {
allocation.handle_timeout(Instant::now());
let message = next_stun_message(allocation).unwrap();
fn make_allocation(allocation: &mut Allocation, allocate_id: TransactionId, local: SocketAddr) {
allocation.handle_input(
RELAY,
local,
&encode(allocate_response(message.transaction_id())),
&encode(allocate_response(allocate_id)),
Instant::now(),
);
}

View File

@@ -30,3 +30,35 @@ pub fn new(
clock: ManualClock { now },
}
}
/// Calculates our backoff times, starting from the given [`Instant`].
///
/// The current strategy is multiplying the previous interval by 1.5 and adding them up.
#[cfg(test)]
pub fn steps(start: Instant) -> [Instant; 19] {
fn secs(secs: f64) -> Duration {
Duration::from_micros((secs * 1_000_000.0) as u64)
}
[
start + secs(5.0),
start + secs(5.0 + 7.5),
start + secs(5.0 + 7.5 + 11.25),
start + secs(5.0 + 7.5 + 11.25 + 16.875),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 1.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 2.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 3.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 4.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 5.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 6.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 7.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 8.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 9.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 10.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 11.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 12.0),
]
}

View File

@@ -8,6 +8,7 @@ mod info;
mod ip_packet;
mod node;
mod stun_binding;
mod utils;
pub use info::ConnectionInfo;
pub use ip_packet::{IpPacket, MutableIpPacket};

View File

@@ -23,6 +23,7 @@ use crate::allocation::Allocation;
use crate::index::IndexLfsr;
use crate::info::ConnectionInfo;
use crate::stun_binding::StunBinding;
use crate::utils::earliest;
use crate::{IpPacket, MutableIpPacket};
use boringtun::noise::errors::WireGuardError;
use std::borrow::Cow;
@@ -450,6 +451,9 @@ where
for b in self.bindings.values_mut() {
connection_timeout = earliest(connection_timeout, b.poll_timeout());
}
for a in self.allocations.values_mut() {
connection_timeout = earliest(connection_timeout, a.poll_timeout());
}
earliest(connection_timeout, self.next_rate_limiter_reset)
}
@@ -809,7 +813,7 @@ where
self.allocations.insert(
*server,
Allocation::new(*server, username, password.clone(), realm),
Allocation::new(*server, username, password.clone(), realm, self.last_now),
);
}
}
@@ -1168,12 +1172,3 @@ impl Connection {
}
}
}
fn earliest(left: Option<Instant>, right: Option<Instant>) -> Option<Instant> {
match (left, right) {
(None, None) => None,
(Some(left), Some(right)) => Some(std::cmp::min(left, right)),
(Some(left), None) => Some(left),
(None, Some(right)) => Some(right),
}
}

View File

@@ -315,32 +315,7 @@ mod tests {
let start = Instant::now();
let mut stun_binding = StunBinding::new(SERVER1, start);
fn secs(secs: f64) -> Duration {
Duration::from_micros((secs * 1_000_000.0) as u64)
}
// The backoff strategy is to increment the previous interval by 1.5
let mut expected_backoffs = VecDeque::from([
start + secs(5.0),
start + secs(5.0 + 7.5),
start + secs(5.0 + 7.5 + 11.25),
start + secs(5.0 + 7.5 + 11.25 + 16.875),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 1.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 2.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 3.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 4.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 5.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 6.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 7.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 8.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 9.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 10.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 11.0),
start + secs(5.0 + 7.5 + 11.25 + 16.875 + 25.3125 + 37.96875 + 56.953125 + 60.0 * 12.0),
]);
let mut expected_backoffs = VecDeque::from(backoff::steps(start));
loop {
let Some(timeout) = stun_binding.poll_timeout() else {

View File

@@ -0,0 +1,10 @@
use std::time::Instant;
pub fn earliest(left: Option<Instant>, right: Option<Instant>) -> Option<Instant> {
match (left, right) {
(None, None) => None,
(Some(left), Some(right)) => Some(std::cmp::min(left, right)),
(Some(left), None) => Some(left),
(None, Some(right)) => Some(right),
}
}