diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 9ea5c3d87..3ce2aabf9 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -131,8 +131,29 @@ impl Allocation { .flatten() } - pub fn uses_credentials(&self, username: &str, password: &str, realm: &str) -> bool { - self.username.name() == username && self.password == password && self.realm.text() == realm + /// Refresh this allocation. + /// + /// In case refreshing the allocation fails, we will attempt to make a new one. + pub fn refresh(&mut self, username: Username, password: &str, realm: Realm) { + self.username = username; + self.realm = realm; + self.password = password.to_owned(); + + if !self.has_allocation() && self.allocate_in_flight() { + tracing::debug!("Not refreshing allocation because we are already making one"); + return; + } + + if self.is_suspended() { + tracing::debug!("Attempting to make a new allocation"); + + self.authenticate_and_queue(make_allocate_request()); + return; + } + + tracing::debug!("Refreshing allocation"); + + self.authenticate_and_queue(make_refresh_request()); } #[tracing::instrument(level = "debug", skip(self, packet, now), fields(relay = %self.server, id, method, class, rtt))] @@ -205,6 +226,9 @@ impl Allocation { } match message.method() { + ALLOCATE => { + self.buffered_channel_bindings.clear(); + } CHANNEL_BIND => { let Some(channel) = original_request .get_attribute::() @@ -226,6 +250,9 @@ impl Allocation { } self.channel_bindings.clear(); + self.allocation_lifetime = None; + + self.authenticate_and_queue(make_allocate_request()); } _ => {} } @@ -393,7 +420,7 @@ impl Allocation { } if let Some(refresh_at) = self.refresh_allocation_at() { - if now > refresh_at { + if now >= refresh_at { tracing::debug!("Allocation is due for a refresh"); self.authenticate_and_queue(make_refresh_request()); } @@ -525,14 +552,6 @@ impl Allocation { }) } - pub fn refresh(&mut self) { - if !self.has_allocation() { - return; - } - - self.authenticate_and_queue(make_refresh_request()); - } - fn has_allocation(&self) -> bool { self.ip4_allocation.is_some() || self.ip6_allocation.is_some() } @@ -552,6 +571,24 @@ impl Allocation { }) } + fn allocate_in_flight(&self) -> bool { + self.sent_requests + .values() + .any(|(r, _, _)| r.method() == ALLOCATE) + } + + /// Check whether this allocation is suspended. + /// + /// We call it suspended if we have given up making an allocation due to some error. + fn is_suspended(&self) -> bool { + let no_allocation = !self.has_allocation(); + let nothing_in_flight = self.sent_requests.is_empty(); + let nothing_buffered = self.buffered_transmits.is_empty(); + let waiting_on_nothing = self.poll_timeout().is_none(); + + no_allocation && nothing_in_flight && nothing_buffered && waiting_on_nothing + } + fn authenticate(&self, message: Message) -> Message { let attributes = message .attributes() @@ -949,6 +986,10 @@ impl BufferedChannelBindings { fn pop_front(&mut self) -> Option> { self.inner.pop_front() } + + fn clear(&mut self) { + self.inner.clear() + } } #[cfg(test)] @@ -958,7 +999,11 @@ mod tests { iter, net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; - use stun_codec::{rfc5389::errors::BadRequest, rfc5766::errors::AllocationMismatch, Message}; + use stun_codec::{ + rfc5389::errors::{BadRequest, ServerError}, + rfc5766::errors::AllocationMismatch, + Message, + }; const PEER1: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 10000); @@ -971,6 +1016,8 @@ mod tests { const MINUTE: Duration = Duration::from_secs(60); + const ALLOCATION_LIFETIME: Duration = Duration::from_secs(600); + #[test] fn returns_first_available_channel() { let mut channel_bindings = ChannelBindings::default(); @@ -1419,7 +1466,7 @@ mod tests { } #[test] - fn calling_refresh_will_trigger_refresh() { + fn calling_refresh_with_same_credentials_will_trigger_refresh() { let mut allocation = Allocation::for_test(Instant::now()); let allocate = allocation.next_message().unwrap(); @@ -1428,10 +1475,13 @@ mod tests { Instant::now(), ); - allocation.refresh(); + allocation.refresh_with_same_credentials(); let refresh = allocation.next_message().unwrap(); assert_eq!(refresh.method(), REFRESH); + + let lifetime = refresh.get_attribute::(); + assert!(lifetime.is_none() || lifetime.is_some_and(|l| l.lifetime() != Duration::ZERO)); } #[test] @@ -1445,7 +1495,7 @@ mod tests { ); let _ = iter::from_fn(|| allocation.poll_event()).collect::>(); // Drain events. - allocation.refresh(); + allocation.refresh_with_same_credentials(); let refresh = allocation.next_message().unwrap(); allocation.handle_test_input(&failed_refresh(&refresh), Instant::now()); @@ -1490,7 +1540,7 @@ mod tests { let msg = allocation.encode_to_vec(PEER2_IP4, b"foobar", Instant::now()); assert!(msg.is_some(), "expect to have a channel to peer"); - allocation.refresh(); + allocation.refresh_with_same_credentials(); let refresh = allocation.next_message().unwrap(); allocation.handle_test_input(&failed_refresh(&refresh), Instant::now()); @@ -1505,12 +1555,117 @@ mod tests { let _allocate = allocation.next_message().unwrap(); - allocation.refresh(); + allocation.refresh_with_same_credentials(); let next_msg = allocation.next_message(); assert!(next_msg.is_none()) } + #[test] + fn failed_refresh_attempts_to_make_new_allocation() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + Instant::now(), + ); + + allocation.refresh_with_same_credentials(); + + let refresh = allocation.next_message().unwrap(); + allocation.handle_test_input(&failed_refresh(&refresh), Instant::now()); + + let allocate = allocation.next_message().unwrap(); + assert_eq!(allocate.method(), ALLOCATE); + } + + #[test] + fn allocation_is_refreshed_after_half_its_lifetime() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + + let received_at = Instant::now(); + + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + received_at, + ); + + let refresh_at = allocation.poll_timeout().unwrap(); + assert_eq!(refresh_at, received_at + (ALLOCATION_LIFETIME / 2)); + + allocation.handle_timeout(refresh_at); + let next_msg = allocation.next_message().unwrap(); + assert_eq!(next_msg.method(), REFRESH) + } + + #[test] + fn failed_refresh_resets_allocation_lifetime() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + Instant::now(), + ); + + allocation.advance_to_next_timeout(); + + let refresh = allocation.next_message().unwrap(); + allocation.handle_test_input(&failed_refresh(&refresh), Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input(&server_error(&allocate), Instant::now()); // These ones are not retried. + + assert_eq!(allocation.poll_timeout(), None); + } + + #[test] + fn when_refreshed_with_no_allocation_after_failed_response_tries_to_allocate() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input(&server_error(&allocate), Instant::now()); + + allocation.refresh_with_same_credentials(); + + let next_msg = allocation.next_message().unwrap(); + assert_eq!(next_msg.method(), ALLOCATE) + } + + #[test] + fn failed_allocation_clears_buffered_channel_bindings() { + let mut allocation = Allocation::for_test(Instant::now()); + + allocation.bind_channel(PEER1, Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input(&server_error(&allocate), Instant::now()); // This should clear the buffered channel bindings. + + allocation.refresh_with_same_credentials(); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input( + &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), + Instant::now(), + ); + + let next_msg = allocation.next_message(); + assert!(next_msg.is_none()) + } + + #[test] + fn failed_allocation_is_suspended() { + let mut allocation = Allocation::for_test(Instant::now()); + + let allocate = allocation.next_message().unwrap(); + allocation.handle_test_input(&server_error(&allocate), Instant::now()); // This should clear the buffered channel bindings. + + assert!(allocation.is_suspended()) + } + fn ch(peer: SocketAddr, now: Instant) -> Channel { Channel { peer, @@ -1533,7 +1688,7 @@ mod tests { message.add_attribute(XorRelayAddress::new(*addr)); } - message.add_attribute(Lifetime::new(Duration::from_secs(600)).unwrap()); + message.add_attribute(Lifetime::new(ALLOCATION_LIFETIME).unwrap()); encode(message) } @@ -1551,6 +1706,17 @@ mod tests { encode(message) } + fn server_error(request: &Message) -> Vec { + let mut message = Message::new( + MessageClass::ErrorResponse, + request.method(), + request.transaction_id(), + ); + message.add_attribute(ErrorCode::from(ServerError)); + + encode(message) + } + fn stale_nonce_response(request: &Message, nonce: Nonce) -> Vec { let mut message = Message::new( MessageClass::ErrorResponse, @@ -1622,5 +1788,19 @@ mod tests { fn handle_test_input(&mut self, packet: &[u8], now: Instant) { self.handle_input(RELAY, PEER1, packet, now); } + + fn advance_to_next_timeout(&mut self) { + if let Some(next) = self.poll_timeout() { + self.handle_timeout(next) + } + } + + fn refresh_with_same_credentials(&mut self) { + self.refresh( + Username::new("foobar".to_owned()).unwrap(), + "baz", + Realm::new("firezone".to_owned()).unwrap(), + ); + } } } diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index b76a8da9d..8025fc25a 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -814,14 +814,6 @@ where fn upsert_turn_servers(&mut self, servers: &HashSet<(SocketAddr, String, String, String)>) { for (server, username, password, realm) in servers { - if let Some(existing) = self.allocations.get_mut(server) { - if existing.uses_credentials(username, password, realm) { - existing.refresh(); - - continue; - } - } - let Ok(username) = Username::new(username.to_owned()) else { tracing::debug!(%username, "Invalid TURN username"); continue; @@ -831,16 +823,17 @@ where continue; }; - let existing = self.allocations.insert( + if let Some(existing) = self.allocations.get_mut(server) { + existing.refresh(username, password, realm); + continue; + } + + self.allocations.insert( *server, Allocation::new(*server, username, password.clone(), realm, self.last_now), ); - if existing.is_some() { - tracing::info!(address = %server, "Replaced existing allocation because credentials to TURN server changed"); - } else { - tracing::info!(address = %server, "Added new TURN server"); - } + tracing::info!(address = %server, "Added new TURN server"); } } diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 3f712c14c..674e1b20b 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -36,39 +36,6 @@ fn answer_after_stale_connection_does_not_panic() { alice.accept_answer(1, bob.public_key(), answer); } -#[test] -fn reinitialize_allocation_if_credentials_for_relay_differ() { - let mut alice = ClientNode::::new( - StaticSecret::random_from_rng(rand::thread_rng()), - Instant::now(), - ); - - // Make a new connection that uses RELAY with initial set of credentials - let _ = alice.new_connection( - 1, - HashSet::new(), - HashSet::from([relay("user1", "pass1", "realm1")]), - ); - - let transmit = alice.poll_transmit().unwrap(); - assert_eq!(transmit.dst, RELAY); - assert!(alice.poll_transmit().is_none()); - - // Make another connection, using the same relay but different credentials (happens when the relay restarts) - - let _ = alice.new_connection( - 2, - HashSet::new(), - HashSet::from([relay("user2", "pass2", "realm1")]), - ); - - // Expect to send another message to the "new" relay - let transmit = alice.poll_transmit().unwrap(); - assert_eq!(transmit.dst, RELAY); - assert_eq!(&transmit.payload[..2], [0x0, 0x3]); // `ALLOCATE` is 0x0003: https://www.rfc-editor.org/rfc/rfc8656#name-stun-methods - assert!(alice.poll_transmit().is_none()); -} - #[test] fn second_connection_with_same_relay_reuses_allocation() { let mut alice = ClientNode::::new(