diff --git a/rust/connlib/socket-factory/src/lib.rs b/rust/connlib/socket-factory/src/lib.rs index 4fd191d25..1ed313224 100644 --- a/rust/connlib/socket-factory/src/lib.rs +++ b/rust/connlib/socket-factory/src/lib.rs @@ -415,7 +415,9 @@ impl UdpSocket { let (num_received, sender) = self.inner.recv_from(&mut buffer).await?; - if sender != dst { + // Even though scopes are technically important for link-local IPv6 addresses, they can be ignored for our purposes. + // We only want to ensure that the reply is from the expected source after we have already received the packet. + if !is_equal_modulo_scope_for_ipv6_link_local(dst, sender) { return Err(io::Error::other(format!( "Unexpected reply source: {sender}; expected: {dst}" ))); @@ -484,6 +486,22 @@ impl UdpSocket { } } +/// Compares the two [`SocketAddr`]s for equality, ignored IPv6 scopes for link-local addresses. +fn is_equal_modulo_scope_for_ipv6_link_local(expected: SocketAddr, actual: SocketAddr) -> bool { + match (expected, actual) { + (SocketAddr::V6(expected), SocketAddr::V6(mut actual)) + if expected.scope_id() == 0 && actual.ip().is_unicast_link_local() => + { + actual.set_scope_id(0); + + expected == actual + } + (SocketAddr::V4(expected), SocketAddr::V4(actual)) => expected == actual, + (SocketAddr::V6(expected), SocketAddr::V6(actual)) => expected == actual, + (SocketAddr::V6(_), SocketAddr::V4(_)) | (SocketAddr::V4(_), SocketAddr::V6(_)) => false, + } +} + /// An iterator that segments an array of buffers into individual datagrams. /// /// This iterator is generic over its buffer type and the number of buffers to allow easier testing without a buffer pool. @@ -619,7 +637,7 @@ where #[cfg(test)] mod tests { use gat_lending_iterator::LendingIterator as _; - use std::net::Ipv4Addr; + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV6}; use super::*; @@ -669,4 +687,22 @@ mod tests { assert_eq!(iter.next().unwrap().packet, b"foo"); assert!(iter.next().is_none()); } + + #[test] + fn scopes_are_ignored_for_link_local_addresses() { + let left = SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0), + 1000, + 0, + 0, + )); + let right = SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0), + 1000, + 0, + 42, + )); + + assert!(is_equal_modulo_scope_for_ipv6_link_local(left, right)) + } }