feat(snownet): attempt to make new allocation when refresh fails (#3631)

Initially, we thought that we needed to replace the entire `Allocation` if
the credentials to the relay change. However, during testing it turned
out that the credentials change every time the portal sends us new
credentials. Likely, the portal hashes some kind of nonce into the
password as well.

Consequently, throwing away the entire state of the `Allocation` is
wrong. Instead, we will simply try to refresh the allocation using the
new credentials. If the refresh fails, we will try to make a new
allocation. If that also fails unrecoverably, then we "suspend" the
allocation, i.e. the `Allocation` will not perform any further action by
itself.

In case we get a new `refresh` call (which happens every time we want to
use the `Allocation` for a connection), we restart things and try to
make a new one.
This commit is contained in:
Thomas Eizinger
2024-02-15 12:41:10 +11:00
committed by GitHub
parent f42aa862a8
commit 23e89c7290
3 changed files with 205 additions and 65 deletions

View File

@@ -131,8 +131,29 @@ impl Allocation {
.flatten()
}
pub fn uses_credentials(&self, username: &str, password: &str, realm: &str) -> bool {
self.username.name() == username && self.password == password && self.realm.text() == realm
/// Refresh this allocation using the given credentials.
///
/// The supplied credentials are stored unconditionally. Afterwards exactly one of the
/// following happens:
///
/// - nothing, if we hold no allocation but an `ALLOCATE` request is already in flight,
/// - an `ALLOCATE` request is queued, if the allocation is currently suspended,
/// - a `REFRESH` request is queued in every other case.
///
/// In case refreshing the allocation fails, we will attempt to make a new one.
pub fn refresh(&mut self, username: Username, password: &str, realm: Realm) {
    // Adopt the latest credentials first, even if no request ends up being queued.
    self.password = password.to_owned();
    self.realm = realm;
    self.username = username;

    let allocation_pending = !self.has_allocation() && self.allocate_in_flight();

    if allocation_pending {
        tracing::debug!("Not refreshing allocation because we are already making one");
    } else if self.is_suspended() {
        tracing::debug!("Attempting to make a new allocation");
        self.authenticate_and_queue(make_allocate_request());
    } else {
        tracing::debug!("Refreshing allocation");
        self.authenticate_and_queue(make_refresh_request());
    }
}
#[tracing::instrument(level = "debug", skip(self, packet, now), fields(relay = %self.server, id, method, class, rtt))]
@@ -205,6 +226,9 @@ impl Allocation {
}
match message.method() {
ALLOCATE => {
self.buffered_channel_bindings.clear();
}
CHANNEL_BIND => {
let Some(channel) = original_request
.get_attribute::<ChannelNumber>()
@@ -226,6 +250,9 @@ impl Allocation {
}
self.channel_bindings.clear();
self.allocation_lifetime = None;
self.authenticate_and_queue(make_allocate_request());
}
_ => {}
}
@@ -393,7 +420,7 @@ impl Allocation {
}
if let Some(refresh_at) = self.refresh_allocation_at() {
if now > refresh_at {
if now >= refresh_at {
tracing::debug!("Allocation is due for a refresh");
self.authenticate_and_queue(make_refresh_request());
}
@@ -525,14 +552,6 @@ impl Allocation {
})
}
pub fn refresh(&mut self) {
if !self.has_allocation() {
return;
}
self.authenticate_and_queue(make_refresh_request());
}
/// Returns `true` if at least one of `ip4_allocation` / `ip6_allocation` is set.
fn has_allocation(&self) -> bool {
    let missing_both = self.ip4_allocation.is_none() && self.ip6_allocation.is_none();

    !missing_both
}
@@ -552,6 +571,24 @@ impl Allocation {
})
}
/// Returns `true` if any entry in `sent_requests` is an `ALLOCATE` request.
fn allocate_in_flight(&self) -> bool {
    for (request, _, _) in self.sent_requests.values() {
        if request.method() == ALLOCATE {
            return true;
        }
    }

    false
}
/// Check whether this allocation is suspended.
///
/// An allocation counts as suspended when we have given up on making one due to some
/// error: there is no allocation, no request in flight, nothing buffered for
/// transmission and no pending timeout.
fn is_suspended(&self) -> bool {
    // All conditions are evaluated up-front; the allocation is suspended only if every
    // single one of them holds.
    let idle_conditions = [
        !self.has_allocation(),
        self.sent_requests.is_empty(),
        self.buffered_transmits.is_empty(),
        self.poll_timeout().is_none(),
    ];

    idle_conditions.iter().all(|&holds| holds)
}
fn authenticate(&self, message: Message<Attribute>) -> Message<Attribute> {
let attributes = message
.attributes()
@@ -949,6 +986,10 @@ impl BufferedChannelBindings {
/// Removes and returns the front message of the buffer via `inner.pop_front()`;
/// the `Option` return type allows for there being nothing to return.
fn pop_front(&mut self) -> Option<Message<Attribute>> {
self.inner.pop_front()
}
/// Discards everything currently held in the buffer by clearing `inner`.
fn clear(&mut self) {
self.inner.clear()
}
}
#[cfg(test)]
@@ -958,7 +999,11 @@ mod tests {
iter,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};
use stun_codec::{rfc5389::errors::BadRequest, rfc5766::errors::AllocationMismatch, Message};
use stun_codec::{
rfc5389::errors::{BadRequest, ServerError},
rfc5766::errors::AllocationMismatch,
Message,
};
const PEER1: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 10000);
@@ -971,6 +1016,8 @@ mod tests {
const MINUTE: Duration = Duration::from_secs(60);
const ALLOCATION_LIFETIME: Duration = Duration::from_secs(600);
#[test]
fn returns_first_available_channel() {
let mut channel_bindings = ChannelBindings::default();
@@ -1419,7 +1466,7 @@ mod tests {
}
#[test]
fn calling_refresh_will_trigger_refresh() {
fn calling_refresh_with_same_credentials_will_trigger_refresh() {
let mut allocation = Allocation::for_test(Instant::now());
let allocate = allocation.next_message().unwrap();
@@ -1428,10 +1475,13 @@ mod tests {
Instant::now(),
);
allocation.refresh();
allocation.refresh_with_same_credentials();
let refresh = allocation.next_message().unwrap();
assert_eq!(refresh.method(), REFRESH);
let lifetime = refresh.get_attribute::<Lifetime>();
assert!(lifetime.is_none() || lifetime.is_some_and(|l| l.lifetime() != Duration::ZERO));
}
#[test]
@@ -1445,7 +1495,7 @@ mod tests {
);
let _ = iter::from_fn(|| allocation.poll_event()).collect::<Vec<_>>(); // Drain events.
allocation.refresh();
allocation.refresh_with_same_credentials();
let refresh = allocation.next_message().unwrap();
allocation.handle_test_input(&failed_refresh(&refresh), Instant::now());
@@ -1490,7 +1540,7 @@ mod tests {
let msg = allocation.encode_to_vec(PEER2_IP4, b"foobar", Instant::now());
assert!(msg.is_some(), "expect to have a channel to peer");
allocation.refresh();
allocation.refresh_with_same_credentials();
let refresh = allocation.next_message().unwrap();
allocation.handle_test_input(&failed_refresh(&refresh), Instant::now());
@@ -1505,12 +1555,117 @@ mod tests {
let _allocate = allocation.next_message().unwrap();
allocation.refresh();
allocation.refresh_with_same_credentials();
let next_msg = allocation.next_message();
assert!(next_msg.is_none())
}
// A failed refresh must be followed by an attempt to allocate from scratch.
#[test]
fn failed_refresh_attempts_to_make_new_allocation() {
let mut allocation = Allocation::for_test(Instant::now());
// Answer the initial ALLOCATE request successfully.
let allocate = allocation.next_message().unwrap();
allocation.handle_test_input(
&allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]),
Instant::now(),
);
// Trigger a refresh and answer it with a failure response.
allocation.refresh_with_same_credentials();
let refresh = allocation.next_message().unwrap();
allocation.handle_test_input(&failed_refresh(&refresh), Instant::now());
// The next queued message is expected to be a new ALLOCATE request.
let allocate = allocation.next_message().unwrap();
assert_eq!(allocate.method(), ALLOCATE);
}
// The refresh timer is expected to fire at half of the lifetime granted in the allocate response.
#[test]
fn allocation_is_refreshed_after_half_its_lifetime() {
let mut allocation = Allocation::for_test(Instant::now());
let allocate = allocation.next_message().unwrap();
// Remember when the response was handed in; the deadline below is computed relative to it.
let received_at = Instant::now();
allocation.handle_test_input(
&allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]),
received_at,
);
let refresh_at = allocation.poll_timeout().unwrap();
assert_eq!(refresh_at, received_at + (ALLOCATION_LIFETIME / 2));
// Handling the timeout exactly at the deadline must already queue the REFRESH request.
allocation.handle_timeout(refresh_at);
let next_msg = allocation.next_message().unwrap();
assert_eq!(next_msg.method(), REFRESH)
}
// After a failed refresh followed by a failed re-allocation, no timeout may remain pending.
#[test]
fn failed_refresh_resets_allocation_lifetime() {
let mut allocation = Allocation::for_test(Instant::now());
// Answer the initial ALLOCATE request successfully.
let allocate = allocation.next_message().unwrap();
allocation.handle_test_input(
&allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]),
Instant::now(),
);
// Fire the pending timeout, then fail the request that it queued.
allocation.advance_to_next_timeout();
let refresh = allocation.next_message().unwrap();
allocation.handle_test_input(&failed_refresh(&refresh), Instant::now());
// Fail the follow-up request as well, this time with a server error.
let allocate = allocation.next_message().unwrap();
allocation.handle_test_input(&server_error(&allocate), Instant::now()); // These ones are not retried.
assert_eq!(allocation.poll_timeout(), None);
}
// Calling `refresh` after the initial allocation attempt failed must queue a new ALLOCATE request.
#[test]
fn when_refreshed_with_no_allocation_after_failed_response_tries_to_allocate() {
let mut allocation = Allocation::for_test(Instant::now());
// Fail the initial ALLOCATE request with a server error.
let allocate = allocation.next_message().unwrap();
allocation.handle_test_input(&server_error(&allocate), Instant::now());
allocation.refresh_with_same_credentials();
let next_msg = allocation.next_message().unwrap();
assert_eq!(next_msg.method(), ALLOCATE)
}
// A channel binding requested before a failed allocation must not be sent once a later allocation succeeds.
#[test]
fn failed_allocation_clears_buffered_channel_bindings() {
let mut allocation = Allocation::for_test(Instant::now());
// Request a channel binding before any allocate response has been handled.
allocation.bind_channel(PEER1, Instant::now());
let allocate = allocation.next_message().unwrap();
allocation.handle_test_input(&server_error(&allocate), Instant::now()); // This should clear the buffered channel bindings.
// Retry and let the second allocation attempt succeed.
allocation.refresh_with_same_credentials();
let allocate = allocation.next_message().unwrap();
allocation.handle_test_input(
&allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]),
Instant::now(),
);
// No further message is expected, i.e. the earlier binding request was dropped.
let next_msg = allocation.next_message();
assert!(next_msg.is_none())
}
// An allocation whose initial ALLOCATE request fails with a server error must report itself as suspended.
#[test]
fn failed_allocation_is_suspended() {
let mut allocation = Allocation::for_test(Instant::now());
let allocate = allocation.next_message().unwrap();
allocation.handle_test_input(&server_error(&allocate), Instant::now()); // This error is expected to leave the allocation suspended.
assert!(allocation.is_suspended())
}
fn ch(peer: SocketAddr, now: Instant) -> Channel {
Channel {
peer,
@@ -1533,7 +1688,7 @@ mod tests {
message.add_attribute(XorRelayAddress::new(*addr));
}
message.add_attribute(Lifetime::new(Duration::from_secs(600)).unwrap());
message.add_attribute(Lifetime::new(ALLOCATION_LIFETIME).unwrap());
encode(message)
}
@@ -1551,6 +1706,17 @@ mod tests {
encode(message)
}
/// Builds an encoded error response to `request` that carries a `ServerError` error code.
///
/// The response reuses the method and transaction ID of `request`.
fn server_error(request: &Message<Attribute>) -> Vec<u8> {
    let (method, transaction_id) = (request.method(), request.transaction_id());

    let mut response = Message::new(MessageClass::ErrorResponse, method, transaction_id);
    response.add_attribute(ErrorCode::from(ServerError));

    encode(response)
}
fn stale_nonce_response(request: &Message<Attribute>, nonce: Nonce) -> Vec<u8> {
let mut message = Message::new(
MessageClass::ErrorResponse,
@@ -1622,5 +1788,19 @@ mod tests {
/// Test helper: forwards `packet` to `handle_input`, always passing the `RELAY` and `PEER1` constants as the first two arguments.
fn handle_test_input(&mut self, packet: &[u8], now: Instant) {
self.handle_input(RELAY, PEER1, packet, now);
}
/// Test helper: handles the timeout reported by `poll_timeout`, if there is one.
///
/// Does nothing when no timeout is pending.
fn advance_to_next_timeout(&mut self) {
    let Some(deadline) = self.poll_timeout() else {
        return;
    };

    self.handle_timeout(deadline)
}
/// Test helper: calls `refresh` with a fixed set of credentials
/// (username `foobar`, password `baz`, realm `firezone`).
// NOTE(review): presumably these match the credentials used by `Allocation::for_test` - confirm.
fn refresh_with_same_credentials(&mut self) {
    let username = Username::new("foobar".to_owned()).unwrap();
    let realm = Realm::new("firezone".to_owned()).unwrap();

    self.refresh(username, "baz", realm);
}
}
}

View File

@@ -814,14 +814,6 @@ where
fn upsert_turn_servers(&mut self, servers: &HashSet<(SocketAddr, String, String, String)>) {
for (server, username, password, realm) in servers {
if let Some(existing) = self.allocations.get_mut(server) {
if existing.uses_credentials(username, password, realm) {
existing.refresh();
continue;
}
}
let Ok(username) = Username::new(username.to_owned()) else {
tracing::debug!(%username, "Invalid TURN username");
continue;
@@ -831,16 +823,17 @@ where
continue;
};
let existing = self.allocations.insert(
if let Some(existing) = self.allocations.get_mut(server) {
existing.refresh(username, password, realm);
continue;
}
self.allocations.insert(
*server,
Allocation::new(*server, username, password.clone(), realm, self.last_now),
);
if existing.is_some() {
tracing::info!(address = %server, "Replaced existing allocation because credentials to TURN server changed");
} else {
tracing::info!(address = %server, "Added new TURN server");
}
tracing::info!(address = %server, "Added new TURN server");
}
}

View File

@@ -36,39 +36,6 @@ fn answer_after_stale_connection_does_not_panic() {
alice.accept_answer(1, bob.public_key(), answer);
}
#[test]
fn reinitialize_allocation_if_credentials_for_relay_differ() {
let mut alice = ClientNode::<u64>::new(
StaticSecret::random_from_rng(rand::thread_rng()),
Instant::now(),
);
// Make a new connection that uses RELAY with initial set of credentials
let _ = alice.new_connection(
1,
HashSet::new(),
HashSet::from([relay("user1", "pass1", "realm1")]),
);
let transmit = alice.poll_transmit().unwrap();
assert_eq!(transmit.dst, RELAY);
assert!(alice.poll_transmit().is_none());
// Make another connection, using the same relay but different credentials (happens when the relay restarts)
let _ = alice.new_connection(
2,
HashSet::new(),
HashSet::from([relay("user2", "pass2", "realm1")]),
);
// Expect to send another message to the "new" relay
let transmit = alice.poll_transmit().unwrap();
assert_eq!(transmit.dst, RELAY);
assert_eq!(&transmit.payload[..2], [0x0, 0x3]); // `ALLOCATE` is 0x0003: https://www.rfc-editor.org/rfc/rfc8656#name-stun-methods
assert!(alice.poll_transmit().is_none());
}
#[test]
fn second_connection_with_same_relay_reuses_allocation() {
let mut alice = ClientNode::<u64>::new(