diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index 1cf298a1d..7abd7abdb 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -131,6 +131,10 @@ impl Allocation { .flatten() } + pub fn uses_credentials(&self, username: &str, password: &str, realm: &str) -> bool { + self.username.name() == username && self.password == password && self.realm.text() == realm + } + #[tracing::instrument(level = "debug", skip(self, packet, now), fields(relay = %self.server, id, method, class, rtt))] pub fn handle_input( &mut self, diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index cc0f920a8..17e884849 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -779,7 +779,7 @@ where fn upsert_stun_servers(&mut self, servers: &HashSet) { for server in servers { if !self.bindings.contains_key(server) { - tracing::debug!(address = %server, "Adding new STUN server"); + tracing::info!(address = %server, "Adding new STUN server"); self.bindings .insert(*server, StunBinding::new(*server, self.last_now)); @@ -789,6 +789,14 @@ where fn upsert_turn_servers(&mut self, servers: &HashSet<(SocketAddr, String, String, String)>) { for (server, username, password, realm) in servers { + if self + .allocations + .get(server) + .is_some_and(|a| a.uses_credentials(username, password, realm)) + { + continue; + } + let Ok(username) = Username::new(username.to_owned()) else { tracing::debug!(%username, "Invalid TURN username"); continue; @@ -798,13 +806,15 @@ where continue; }; - if !self.allocations.contains_key(server) { - tracing::debug!(address = %server, "Adding new TURN server"); + let existing = self.allocations.insert( + *server, + Allocation::new(*server, username, password.clone(), realm, self.last_now), + ); - self.allocations.insert( - *server, - Allocation::new(*server, username, password.clone(), realm, self.last_now), - ); + if existing.is_some() { + tracing::info!(address = %server, "Replaced existing allocation because credentials to TURN server changed"); + } else { + tracing::info!(address = %server, "Added new TURN server"); } } } diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 7f851a2ca..5a30f8148 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -2,6 +2,7 @@ use boringtun::x25519::StaticSecret; use snownet::{ClientNode, Event}; use std::{ collections::HashSet, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, time::{Duration, Instant}, }; @@ -17,3 +18,73 @@ fn connection_times_out_after_10_seconds() { assert_eq!(alice.poll_event().unwrap(), Event::ConnectionFailed(1)); } + +#[test] +fn reinitialize_allocation_if_credentials_for_relay_differ() { + let mut alice = ClientNode::::new( + StaticSecret::random_from_rng(rand::thread_rng()), + Instant::now(), + ); + + // Make a new connection that uses RELAY with initial set of credentials + let _ = alice.new_connection( + 1, + HashSet::new(), + HashSet::from([relay("user1", "pass1", "realm1")]), + ); + + let transmit = alice.poll_transmit().unwrap(); + assert_eq!(transmit.dst, RELAY); + assert!(alice.poll_transmit().is_none()); + + // Make another connection, using the same relay but different credentials (happens when the relay restarts) + + let _ = alice.new_connection( + 2, + HashSet::new(), + HashSet::from([relay("user2", "pass2", "realm1")]), + ); + + // Expect to send another message to the "new" relay + let transmit = alice.poll_transmit().unwrap(); + assert_eq!(transmit.dst, RELAY); + assert_eq!(&transmit.payload[..2], [0x0, 0x3]); // `ALLOCATE` is 0x0003: https://www.rfc-editor.org/rfc/rfc8656#name-stun-methods + assert!(alice.poll_transmit().is_none()); +} + +#[test] +fn second_connection_with_same_relay_reuses_allocation() { + let mut alice = ClientNode::::new( + StaticSecret::random_from_rng(rand::thread_rng()), + Instant::now(), + ); + + let _ = alice.new_connection( + 1, + HashSet::new(), + HashSet::from([relay("user1", "pass1", "realm1")]), + ); + + let transmit = alice.poll_transmit().unwrap(); + assert_eq!(transmit.dst, RELAY); + assert!(alice.poll_transmit().is_none()); + + let _ = alice.new_connection( + 2, + HashSet::new(), + HashSet::from([relay("user1", "pass1", "realm1")]), + ); + + assert!(alice.poll_transmit().is_none()); +} + +fn relay(username: &str, pass: &str, realm: &str) -> (SocketAddr, String, String, String) { + ( + RELAY, + username.to_owned(), + pass.to_owned(), + realm.to_owned(), + ) +} + +const RELAY: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 10000));