diff --git a/rust/Cargo.lock b/rust/Cargo.lock index b43cabb29..8c376dd3c 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -6044,6 +6044,7 @@ dependencies = [ "boringtun", "bytecodec", "bytes", + "derive_more 1.0.0", "firezone-logging", "hex", "hex-display", diff --git a/rust/connlib/snownet/Cargo.toml b/rust/connlib/snownet/Cargo.toml index a7f53d379..d2a1f88b8 100644 --- a/rust/connlib/snownet/Cargo.toml +++ b/rust/connlib/snownet/Cargo.toml @@ -9,6 +9,7 @@ backoff = { workspace = true } boringtun = { workspace = true } bytecodec = { workspace = true } bytes = { workspace = true } +derive_more = { workspace = true, features = ["debug"] } firezone-logging = { workspace = true } hex = { workspace = true } hex-display = { workspace = true } diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index c1b5e35bb..6f524aeca 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -37,6 +37,14 @@ use tracing::{field, Span}; const REQUEST_TIMEOUT: Duration = Duration::from_secs(1); +/// How often to send a STUN binding request after the initial connection to the relay. +/// +/// Most NATs keep _confirmed_ UDP bindings around for 120s. +/// Unconfirmed UDP bindings are usually kept around for 30s. +/// The binding interval here is chosen very conservatively to reflect these. +/// It ain't much traffic and with a lower interval, these checks can also help in disconnecting from an unresponsive relay. +const BINDING_INTERVAL: Duration = Duration::from_secs(25); + /// Represents a TURN allocation that refreshes itself. /// /// Allocations have a lifetime and need to be continuously refreshed to stay active. @@ -51,7 +59,10 @@ pub struct Allocation { /// /// To figure out, how to communicate with the relay, we start by sending a BINDING request on all known sockets. /// Whatever comes back first, wins. - active_socket: Option, + /// + /// Once set, we send STUN binding requests at an interval of [`BINDING_INTERVAL`]. + /// This ensures any NAT bindings stay alive even if the allocation is completely idle. + active_socket: Option, software: Software, @@ -91,6 +102,13 @@ pub struct Allocation { explicit_failure: Option, } +#[derive(derive_more::Debug, Clone, Copy)] +#[debug("{addr}")] +struct ActiveSocket { + addr: SocketAddr, + next_binding: Instant, +} + #[derive(Debug, PartialEq)] pub(crate) enum Event { New(Candidate), @@ -503,14 +521,14 @@ impl Allocation { // We send 2 BINDING requests to start with (one for each IP version) and the first one coming back wins. // Thus, if we already have a socket set, we are done with processing this binding request. - if let Some(active_socket) = self.active_socket { - tracing::debug!(%active_socket, additional_socket = %original_dst, "Relay supports dual-stack but we've already picked a socket"); + if let Some(active_socket) = self.active_socket.as_ref() { + tracing::debug!(active_socket = %active_socket.addr, additional_socket = %original_dst, "Relay supports dual-stack but we've already picked a socket"); return true; } // If the socket isn't set yet, use the `original_dst` as the primary socket. - self.active_socket = Some(original_dst); + self.active_socket = Some(ActiveSocket::new(original_dst, now)); tracing::debug!(active_socket = %original_dst, "Updating active socket"); @@ -643,6 +661,16 @@ impl Allocation { self.invalidate_allocation(); } + if self.has_allocation() { + if let Some(addr) = self + .active_socket + .as_mut() + .and_then(|a| a.handle_timeout(now)) + { + self.queue(addr, make_binding_request(self.software.clone()), None); + } + } + while let Some(timed_out_request) = self.sent_requests .iter() @@ -655,24 +683,28 @@ impl Allocation { .remove(&timed_out_request) .expect("ID is from list"); - tracing::debug!(id = ?request.transaction_id(), method = %request.method(), %dst, "Request timed out after {backoff_duration:?}, re-sending"); + let method = request.method(); - let needs_auth = request.method() != BINDING; - let is_refresh = request.method() == REFRESH; + tracing::debug!(id = ?request.transaction_id(), %method, %dst, "Request timed out after {backoff_duration:?}, re-sending"); - if needs_auth { - let queued = self.authenticate_and_queue(request, Some(backoff)); + let needs_auth = method != BINDING; - // If we fail to queue the refresh message because we've exceeded our backoff, give up. - if !queued && is_refresh { - self.active_socket = None; // The socket seems to no longer be reachable. - self.invalidate_allocation(); - } + let queued = if needs_auth { + self.authenticate_and_queue(request, Some(backoff)) + } else { + self.queue(dst, request, Some(backoff)) + }; - continue; + // If we have an active socket (i.e. successfully sent at least 1 BINDING request) + // and we just timed out a message, invalidate the allocation. + if !queued + && self + .active_socket + .is_some_and(|s| s.same_ip_version_as(dst)) + { + self.active_socket = None; // The socket seems to no longer be reachable. + self.invalidate_allocation(); } - - self.queue(dst, request, Some(backoff)); } if let Some(refresh_at) = self.refresh_allocation_at() { @@ -719,7 +751,13 @@ impl Allocation { earliest_timeout = earliest(earliest_timeout, Some(*sent_at + *backoff)); } - earliest_timeout + let next_keepalive = if self.has_allocation() { + self.active_socket.map(|a| a.next_binding) + } else { + None + }; + + earliest(earliest_timeout, next_keepalive) } #[tracing::instrument(level = "debug", skip(self, now), fields(active_socket = ?self.active_socket))] @@ -774,7 +812,7 @@ impl Allocation { buffer: &mut [u8], now: Instant, ) -> Option { - let active_socket = self.active_socket?; + let active_socket = self.active_socket?.addr; let payload_length = buffer.len() - 4; let channel_number = match self.channel_bindings.connected_channel_to_peer(peer, now) { @@ -991,7 +1029,7 @@ impl Allocation { message: Message, backoff: Option, ) -> bool { - let Some(dst) = self.active_socket else { + let Some(active_socket) = self.active_socket else { tracing::debug!( "Unable to queue {} because we haven't nominated a socket yet", message.method() @@ -1008,7 +1046,7 @@ impl Allocation { }; let authenticated_message = authenticate(message, credentials); - self.queue(dst, authenticated_message, backoff) + self.queue(active_socket.addr, authenticated_message, backoff) } fn queue( @@ -1084,6 +1122,29 @@ pub struct EncodeOk { pub socket: SocketAddr, } +impl ActiveSocket { + fn new(addr: SocketAddr, now: Instant) -> Self { + Self { + addr, + next_binding: now + BINDING_INTERVAL, + } + } + + fn same_ip_version_as(&self, dst: SocketAddr) -> bool { + self.addr.is_ipv4() == dst.is_ipv4() + } + + fn handle_timeout(&mut self, now: Instant) -> Option { + if now < self.next_binding { + return None; + } + + self.next_binding = now + BINDING_INTERVAL; + + Some(self.addr) + } +} + fn authenticate(message: Message, credentials: &Credentials) -> Message { let attributes = message .attributes() @@ -2158,23 +2219,21 @@ mod tests { #[test] fn allocation_is_refreshed_after_half_its_lifetime() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); + let mut now = Instant::now(); + let mut allocation = Allocation::for_test_ip4(now).with_binding_response(PEER1); let allocate = allocation.next_message().unwrap(); - let received_at = Instant::now(); - allocation.handle_test_input_ip4( &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), - received_at, + now, ); - let refresh_at = allocation.poll_timeout().unwrap(); - assert_eq!(refresh_at, received_at + (ALLOCATION_LIFETIME / 2)); + now += ALLOCATION_LIFETIME / 2; + allocation.handle_timeout(now); - allocation.handle_timeout(refresh_at); - let next_msg = allocation.next_message().unwrap(); - assert_eq!(next_msg.method(), REFRESH) + let refresh = iter::from_fn(|| allocation.next_message()).find(|m| m.method() == REFRESH); + assert!(refresh.is_some()); } #[test] @@ -2194,27 +2253,6 @@ mod tests { assert!(allocation.poll_timeout().unwrap() > refresh_at); } - #[test] - fn failed_refresh_resets_allocation_lifetime() { - let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); - - let allocate = allocation.next_message().unwrap(); - allocation.handle_test_input_ip4( - &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), - Instant::now(), - ); - - allocation.advance_to_next_timeout(); - - let refresh = allocation.next_message().unwrap(); - allocation.handle_test_input_ip4(&allocation_mismatch(&refresh), Instant::now()); - - let allocate = allocation.next_message().unwrap(); - allocation.handle_test_input_ip4(&server_error(&allocate), Instant::now()); // These ones are not retried. - - assert_eq!(allocation.poll_timeout(), None); - } - #[test] fn when_refreshed_with_no_allocation_after_failed_response_tries_to_allocate() { let mut allocation = Allocation::for_test_ip4(Instant::now()).with_binding_response(PEER1); @@ -2377,33 +2415,33 @@ mod tests { } #[test] - fn timed_out_refresh_requests_invalid_candidates() { + fn timed_out_binding_requests_invalid_candidates() { let _guard = firezone_logging::test("trace"); - let start = Instant::now(); - let mut allocation = Allocation::for_test_ip4(start).with_binding_response(PEER1); + let mut now = Instant::now(); + let mut allocation = Allocation::for_test_ip4(now).with_binding_response(PEER1); // Make an allocation { let allocate = allocation.next_message().unwrap(); allocation.handle_test_input_ip4( &allocate_response(&allocate, &[RELAY_ADDR_IP4, RELAY_ADDR_IP6]), - start, + now, ); let _drained_events = iter::from_fn(|| allocation.poll_event()).collect::>(); } - // Test that we refresh it. + // Test that we send binding requests it. { - let refresh_at = allocation.poll_timeout().unwrap(); - allocation.handle_timeout(refresh_at); + now = allocation.poll_timeout().unwrap(); + allocation.handle_timeout(now); - let refresh = allocation.next_message().unwrap(); - assert_eq!(refresh.method(), REFRESH); + let binding = allocation.next_message().unwrap(); + assert_eq!(binding.method(), BINDING); } - // Simulate refresh timing out - for _ in backoff::steps(start) { + // Simulate bindings timing out + for _ in backoff::steps(now) { allocation.handle_timeout(allocation.poll_timeout().unwrap()); } @@ -2616,6 +2654,25 @@ mod tests { ); } + #[test] + fn sends_binding_request_on_nominated_socket() { + let mut now = Instant::now(); + + let mut allocation = Allocation::for_test_ip4(now) + .with_binding_response(PEER1) + .with_allocate_response(&[RELAY_ADDR_IP4]); + + now += BINDING_INTERVAL; + allocation.handle_timeout(now); + + let transmit = allocation.poll_transmit().unwrap(); + assert_eq!(transmit.dst, RELAY_V4.into()); + assert_eq!( + decode(&transmit.payload).unwrap().unwrap().method(), + BINDING + ); + } + fn ch(peer: SocketAddr, now: Instant) -> Channel { Channel { peer, @@ -2776,12 +2833,6 @@ mod tests { self.handle_input(RELAY_V4.into(), PEER1, packet, now) } - fn advance_to_next_timeout(&mut self) { - if let Some(next) = self.poll_timeout() { - self.handle_timeout(next) - } - } - fn refresh_with_same_credentials(&mut self) { self.refresh(Instant::now()); }