From 23e89c7290178182ac93bef7330d99cbfb0e2044 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 15 Feb 2024 12:41:10 +1100 Subject: [PATCH] feat(snownet): attempt to make new allocation when refresh fails (#3631) Initially, we thought that we need to replace the entire `Allocation` if the credentials to the relay change. However, during testing it turned out that the credentials will change every time the portal sends us new credentials. Likely, the portal hashes some kind of nonce into the password as well. Consequently, throwing away the entire state of the `Allocation` is wrong. Instead, we will simply try to refresh the allocation using the new credentials. If the refresh fails, we will try to make a new allocation. If that also fails unrecoverably, then we "suspend" the allocation, i.e. the `Allocation` will not perform any further action by itself. In case we get a new `refresh` call (which happens every time we want to use the `Allocation` for a connection), we restart things and try to make a new one. --- rust/connlib/snownet/src/allocation.rs | 216 ++++++++++++++++++++++--- rust/connlib/snownet/src/node.rs | 21 +-- rust/connlib/snownet/tests/lib.rs | 33 ---- 3 files changed, 205 insertions(+), 65 deletions(-) diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 9ea5c3d87..3ce2aabf9 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -131,8 +131,29 @@ impl Allocation { .flatten() } - pub fn uses_credentials(&self, username: &str, password: &str, realm: &str) -> bool { - self.username.name() == username && self.password == password && self.realm.text() == realm + /// Refresh this allocation. + /// + /// In case refreshing the allocation fails, we will attempt to make a new one. + pub fn refresh(&mut self, username: Username, password: &str, realm: Realm) { + self.username = username; + self.realm = realm; + self.password = password.to_owned(); + + if !self.has_allocation() && self.allocate_in_flight() { + tracing::debug!("Not refreshing allocation because we are already making one"); + return; + } + + if self.is_suspended() { + tracing::debug!("Attempting to make a new allocation"); + + self.authenticate_and_queue(make_allocate_request()); + return; + } + + tracing::debug!("Refreshing allocation"); + + self.authenticate_and_queue(make_refresh_request()); } #[tracing::instrument(level = "debug", skip(self, packet, now), fields(relay = %self.server, id, method, class, rtt))] @@ -205,6 +226,9 @@ impl Allocation { } match message.method() { + ALLOCATE => { + self.buffered_channel_bindings.clear(); + } CHANNEL_BIND => { let Some(channel) = original_request .get_attribute::() @@ -226,6 +250,9 @@ impl Allocation { } self.channel_bindings.clear(); + self.allocation_lifetime = None; + + self.authenticate_and_queue(make_allocate_request()); } _ => {} } @@ -393,7 +420,7 @@ impl Allocation { } if let Some(refresh_at) = self.refresh_allocation_at() { - if now > refresh_at { + if now >= refresh_at { tracing::debug!("Allocation is due for a refresh"); self.authenticate_and_queue(make_refresh_request()); } @@ -525,14 +552,6 @@ impl Allocation { }) } - pub fn refresh(&mut self) { - if !self.has_allocation() { - return; - } - - self.authenticate_and_queue(make_refresh_request()); - } - fn has_allocation(&self) -> bool { self.ip4_allocation.is_some() || self.ip6_allocation.is_some() } @@ -552,6 +571,24 @@ impl Allocation { }) } + fn allocate_in_flight(&self) -> bool { + self.sent_requests + .values() + .any(|(r, _, _)| r.method() == ALLOCATE) + } + + /// Check whether this allocation is suspended. + /// + /// We call it suspended if we have given up making an allocation due to some error. + fn is_suspended(&self) -> bool { + let no_allocation = !self.has_allocation(); + let nothing_in_flight = self.sent_requests.is_empty(); + let nothing_buffered = self.buffered_transmits.is_empty(); + let waiting_on_nothing = self.poll_timeout().is_none(); + + no_allocation && nothing_in_flight && nothing_buffered && waiting_on_nothing + } + fn authenticate(&self, message: Message) -> Message { let attributes = message .attributes() @@ -949,6 +986,10 @@ impl BufferedChannelBindings { fn pop_front(&mut self) -> Option> { self.inner.pop_front() } + + fn clear(&mut self) { + self.inner.clear() + } } #[cfg(test)] @@ -958,7 +999,11 @@ mod tests { iter, net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; - use stun_codec::{rfc5389::errors::BadRequest, rfc5766::errors::AllocationMismatch, Message}; + use stun_codec::{ + rfc5389::errors::{BadRequest, ServerError}, + rfc5766::errors::AllocationMismatch, + Message, + }; const PEER1: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 10000); @@ -971,6 +1016,8 @@ mod tests { const MINUTE: Duration = Duration::from_secs(60); + const ALLOCATION_LIFETIME: Duration = Duration::from_secs(600); + #[test] fn returns_first_available_channel() { let mut channel_bindings = ChannelBindings::default(); @@ -1419,7 +1466,7 @@ mod tests { } #[test] - fn calling_refresh_will_trigger_refresh() { + fn calling_refresh_with_same_credentials_will_trigger_refresh() { let mut allocation = Allocation::for_test(Instant::now()); let allocate = allocation.next_message().unwrap(); @@ -1428,10 +1475,13 @@ mod tests { Instant::now(), ); - allocation.refresh(); + allocation.refresh_with_same_credentials(); let refresh = allocation.next_message().unwrap(); assert_eq!(refresh.method(), REFRESH); + + let lifetime = refresh.get_attribute::(); + assert!(lifetime.is_none() || lifetime.is_some_and(|l| l.lifetime() != Duration::ZERO)); } #[test] @@ -1445,7 +1495,7 @@ mod tests { ); let _ = iter::from_fn(|| allocation.poll_event()).collect::>(); // Drain events. - allocation.refresh(); + allocation.refresh_with_same_credentials(); let refresh = allocation.next_message().unwrap(); allocation.handle_test_input(&failed_refresh(&refresh), Instant::now()); @@ -1490,7 +1540,7 @@ mod tests { let msg = allocation.encode_to_vec(PEER2_IP4, b"foobar", Instant::now()); assert!(msg.is_some(), "expect to have a channel to peer"); - allocation.refresh(); + allocation.refresh_with_same_credentials(); let refresh = allocation.next_message().unwrap(); allocation.handle_test_input(&failed_refresh(&refresh), Instant::now()); @@ -1505,12 +1555,117 @@ mod tests { let _allocate = allocation.next_message().unwrap(); - allocation.refresh(); + allocation.refresh_with_same_credentials(); let next_msg = allocation.next_message(); assert!(next_msg.is_none()) } + #[test] + fn failed_refresh_attempts_to_make_new_allocation() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + Instant::now(), + ); + + allocation.refresh_with_same_credentials(); + + let refresh = allocation.next_message().unwrap(); + allocation.handle_test_input(&failed_refresh(&refresh), Instant::now()); + + let allocate = allocation.next_message().unwrap(); + assert_eq!(allocate.method(), ALLOCATE); + } + + #[test] + fn allocation_is_refreshed_after_half_its_lifetime() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + + let received_at = Instant::now(); + + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + received_at, + ); + + let refresh_at = allocation.poll_timeout().unwrap(); + assert_eq!(refresh_at, received_at + (ALLOCATION_LIFETIME / 2)); + + allocation.handle_timeout(refresh_at); + let next_msg = allocation.next_message().unwrap(); + assert_eq!(next_msg.method(), REFRESH) + } + + #[test] + fn failed_refresh_resets_allocation_lifetime() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + Instant::now(), + ); + + allocation.advance_to_next_timeout(); + + let refresh = allocation.next_message().unwrap(); + allocation.handle_test_input(&failed_refresh(&refresh), Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input(&server_error(&allocate), Instant::now()); // These ones are not retried. + + assert_eq!(allocation.poll_timeout(), None); + } + + #[test] + fn when_refreshed_with_no_allocation_after_failed_response_tries_to_allocate() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input(&server_error(&allocate), Instant::now()); + + allocation.refresh_with_same_credentials(); + + let next_msg = allocation.next_message().unwrap(); + assert_eq!(next_msg.method(), ALLOCATE) + } + + #[test] + fn failed_allocation_clears_buffered_channel_bindings() { + let mut allocation = Allocation::for_test(Instant::now()); + + allocation.bind_channel(PEER1, Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input(&server_error(&allocate), Instant::now()); // This should clear the buffered channel bindings. + + allocation.refresh_with_same_credentials(); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + Instant::now(), + ); + + let next_msg = allocation.next_message(); + assert!(next_msg.is_none()) + } + + #[test] + fn failed_allocation_is_suspended() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input(&server_error(&allocate), Instant::now()); // This should clear the buffered channel bindings. + + assert!(allocation.is_suspended()) + } + fn ch(peer: SocketAddr, now: Instant) -> Channel { Channel { peer, @@ -1533,7 +1688,7 @@ mod tests { message.add_attribute(XorRelayAddress::new(*addr)); } - message.add_attribute(Lifetime::new(Duration::from_secs(600)).unwrap()); + message.add_attribute(Lifetime::new(ALLOCATION_LIFETIME).unwrap()); encode(message) } @@ -1551,6 +1706,17 @@ mod tests { encode(message) } + fn server_error(request: &Message) -> Vec { + let mut message = Message::new( + MessageClass::ErrorResponse, + request.method(), + request.transaction_id(), + ); + message.add_attribute(ErrorCode::from(ServerError)); + + encode(message) + } + fn stale_nonce_response(request: &Message, nonce: Nonce) -> Vec { let mut message = Message::new( MessageClass::ErrorResponse, @@ -1622,5 +1788,19 @@ mod tests { fn handle_test_input(&mut self, packet: &[u8], now: Instant) { self.handle_input(RELAY, PEER1, packet, now); } + + fn advance_to_next_timeout(&mut self) { + if let Some(next) = self.poll_timeout() { + self.handle_timeout(next) + } + } + + fn refresh_with_same_credentials(&mut self) { + self.refresh( + Username::new("foobar".to_owned()).unwrap(), + "baz", + Realm::new("firezone".to_owned()).unwrap(), + ); + } } } diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index b76a8da9d..8025fc25a 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -814,14 +814,6 @@ where fn upsert_turn_servers(&mut self, servers: &HashSet<(SocketAddr, String, String, String)>) { for (server, username, password, realm) in servers { - if let Some(existing) = self.allocations.get_mut(server) { - if existing.uses_credentials(username, password, realm) { - existing.refresh(); - - continue; - } - } - let Ok(username) = Username::new(username.to_owned()) else { tracing::debug!(%username, "Invalid TURN username"); continue; @@ -831,16 +823,17 @@ where continue; }; - let existing = self.allocations.insert( + if let Some(existing) = self.allocations.get_mut(server) { + existing.refresh(username, password, realm); + continue; + } + + self.allocations.insert( *server, Allocation::new(*server, username, password.clone(), realm, self.last_now), ); - if existing.is_some() { - tracing::info!(address = %server, "Replaced existing allocation because credentials to TURN server changed"); - } else { - tracing::info!(address = %server, "Added new TURN server"); - } + tracing::info!(address = %server, "Added new TURN server"); } } diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 3f712c14c..674e1b20b 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -36,39 +36,6 @@ fn answer_after_stale_connection_does_not_panic() { alice.accept_answer(1, bob.public_key(), answer); } -#[test] -fn reinitialize_allocation_if_credentials_for_relay_differ() { - let mut alice = ClientNode::::new( - StaticSecret::random_from_rng(rand::thread_rng()), - Instant::now(), - ); - - // Make a new connection that uses RELAY with initial set of credentials - let _ = alice.new_connection( - 1, - HashSet::new(), - HashSet::from([relay("user1", "pass1", "realm1")]), - ); - - let transmit = alice.poll_transmit().unwrap(); - assert_eq!(transmit.dst, RELAY); - assert!(alice.poll_transmit().is_none()); - - // Make another connection, using the same relay but different credentials (happens when the relay restarts) - - let _ = alice.new_connection( - 2, - HashSet::new(), - HashSet::from([relay("user2", "pass2", "realm1")]), - ); - - // Expect to send another message to the "new" relay - let transmit = alice.poll_transmit().unwrap(); - assert_eq!(transmit.dst, RELAY); - assert_eq!(&transmit.payload[..2], [0x0, 0x3]); // `ALLOCATE` is 0x0003: https://www.rfc-editor.org/rfc/rfc8656#name-stun-methods - assert!(alice.poll_transmit().is_none()); -} - #[test] fn second_connection_with_same_relay_reuses_allocation() { let mut alice = ClientNode::::new(