From 632dfdd888fa19a008780b1a8e613da1783d2c84 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 2 Aug 2023 03:50:43 +0200 Subject: [PATCH] feat(relay): support IPv6 allocations (#1814) This patch series adds support for IPv6 allocations. If not specified otherwise in the ALLOCATE request, clients will get an IP4 allocation. They can also request an IPv6 address or an additional IPv6 address in addition to their IPv4 address. Either of those is only possible if the relay actually has a listening socket for the requested address family. The CLI is designed such that the user can either specify IP4, IP6 or both of them. The `Server` component handles all of this logic and responds with either a successful allocation response or an Address Family Not Supported error (see https://www.rfc-editor.org/rfc/rfc8656#name-stun-error-response-codes). Multiple refactorings were necessary to achieve this design, they are all extracted into separate PRs: Depends-On: #1831. Depends-On: #1832. Depends-On: #1833. --------- Co-authored-by: Jamil --- rust/Cargo.lock | 18 +- rust/relay/Cargo.toml | 2 + rust/relay/src/allocation.rs | 12 +- rust/relay/src/lib.rs | 49 ++++ rust/relay/src/main.rs | 139 +++++++---- rust/relay/src/net_ext.rs | 28 +++ rust/relay/src/rfc8656.rs | 299 ++++++++++++++++++++++++ rust/relay/src/server.rs | 222 +++++++++++++++--- rust/relay/src/server/client_message.rs | 116 +++++++-- rust/relay/src/udp_socket.rs | 33 ++- rust/relay/tests/regression.rs | 103 ++++++-- 11 files changed, 889 insertions(+), 132 deletions(-) create mode 100644 rust/relay/src/net_ext.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index ef019af07..e48f26521 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1511,7 +1511,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio", "tower-service", "tracing", @@ -2548,12 +2548,14 @@ dependencies = [ "redis", "serde", "sha2", + "socket2 0.5.3", "stun_codec", "test-strategy", "tokio", "tracing", "tracing-stackdriver", "tracing-subscriber", + "trackable 1.3.0", "url", "uuid", "webrtc", @@ -2977,6 +2979,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "spin" version = "0.5.2" @@ -3294,7 +3306,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.4.9", "tokio-macros", "windows-sys 0.48.0", ] @@ -3917,7 +3929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f08dfd7a6e3987e255c4dbe710dde5d94d0f0574f8a21afa95d171376c143106" dependencies = [ "log", - "socket2", + "socket2 0.4.9", "thiserror", "tokio", "webrtc-util", diff --git a/rust/relay/Cargo.toml b/rust/relay/Cargo.toml index 9e029582e..8ff309a0c 100644 --- a/rust/relay/Cargo.toml +++ b/rust/relay/Cargo.toml @@ -28,6 +28,8 @@ uuid = { version = "1.4.1", features = ["v4"] } phoenix-channel = { path = "../phoenix-channel" } url = "2.4.0" serde = { version = "1.0.179", features = ["derive"] } +trackable = "1.3.0" +socket2 = "0.5.3" prometheus-client = "0.21.1" axum = { version = "0.6.18", default-features = false, features = ["http1", "tokio"] } diff --git a/rust/relay/src/allocation.rs b/rust/relay/src/allocation.rs index fadaf8094..bb527ac22 100644 --- a/rust/relay/src/allocation.rs +++ b/rust/relay/src/allocation.rs @@ -4,7 +4,7 @@ use anyhow::{bail, Result}; use futures::channel::mpsc; use futures::{SinkExt, StreamExt}; use std::convert::Infallible; -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, SocketAddr}; use tokio::task; /// The maximum amount of items that can be buffered in the channel to the allocation task. @@ -24,17 +24,17 @@ impl Allocation { pub fn new( relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, id: AllocationId, - listen_ip4_addr: Ipv4Addr, + listen_addr: IpAddr, port: u16, ) -> Self { let (client_to_peer_sender, client_to_peer_receiver) = mpsc::channel(MAX_BUFFERED_ITEMS); let task = tokio::spawn(async move { - let Err(e) = forward_incoming_relay_data(relay_data_sender, client_to_peer_receiver, id, listen_ip4_addr, port).await else { + let Err(e) = forward_incoming_relay_data(relay_data_sender, client_to_peer_receiver, id, listen_addr, port).await else { unreachable!() }; - tracing::warn!("Allocation task for {id} failed: {e}"); + tracing::warn!(allocation = %id, %listen_addr, "Allocation task failed: {e:#}"); // With the task stopping, the channel will be closed and any attempt to send data to it will fail. }); @@ -83,10 +83,10 @@ async fn forward_incoming_relay_data( mut relayed_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, mut client_to_peer_receiver: mpsc::Receiver<(Vec, SocketAddr)>, id: AllocationId, - listen_ip4_addr: Ipv4Addr, + listen_addr: IpAddr, port: u16, ) -> Result { - let mut socket = UdpSocket::bind((listen_ip4_addr, port)).await?; + let mut socket = UdpSocket::bind((listen_addr, port))?; loop { tokio::select! { diff --git a/rust/relay/src/lib.rs b/rust/relay/src/lib.rs index b6dfeab43..8450c0b93 100644 --- a/rust/relay/src/lib.rs +++ b/rust/relay/src/lib.rs @@ -1,5 +1,6 @@ mod allocation; mod auth; +mod net_ext; mod rfc8656; mod server; mod sleep; @@ -12,6 +13,8 @@ pub mod metrics; pub mod proptest; pub use allocation::Allocation; +pub use net_ext::{IpAddrExt, SocketAddrExt}; +pub use rfc8656::AddressFamily; pub use server::{ Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command, CreatePermission, Refresh, Server, @@ -20,3 +23,49 @@ pub use sleep::Sleep; pub use udp_socket::UdpSocket; pub(crate) use time_events::TimeEvents; + +use std::net::{Ipv4Addr, Ipv6Addr}; + +/// Describes the IP stack of a relay server. +#[derive(Debug, Copy, Clone)] +pub enum IpStack { + Ip4(Ipv4Addr), + Ip6(Ipv6Addr), + Dual { ip4: Ipv4Addr, ip6: Ipv6Addr }, +} + +impl IpStack { + pub fn as_v4(&self) -> Option<&Ipv4Addr> { + match self { + IpStack::Ip4(ip4) => Some(ip4), + IpStack::Ip6(_) => None, + IpStack::Dual { ip4, .. } => Some(ip4), + } + } + + pub fn as_v6(&self) -> Option<&Ipv6Addr> { + match self { + IpStack::Ip4(_) => None, + IpStack::Ip6(ip6) => Some(ip6), + IpStack::Dual { ip6, .. } => Some(ip6), + } + } +} + +impl From for IpStack { + fn from(value: Ipv4Addr) -> Self { + IpStack::Ip4(value) + } +} + +impl From for IpStack { + fn from(value: Ipv6Addr) -> Self { + IpStack::Ip6(value) + } +} + +impl From<(Ipv4Addr, Ipv6Addr)> for IpStack { + fn from((ip4, ip6): (Ipv4Addr, Ipv6Addr)) -> Self { + IpStack::Dual { ip4, ip6 } + } +} diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 58d8cb15e..e1ae2625e 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -6,11 +6,14 @@ use phoenix_channel::{Error, Event, PhoenixChannel}; use prometheus_client::registry::Registry; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use relay::{Allocation, AllocationId, Command, Server, Sleep, UdpSocket}; +use relay::{ + AddressFamily, Allocation, AllocationId, Command, IpStack, Server, Sleep, SocketAddrExt, + UdpSocket, +}; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::convert::Infallible; -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::pin::Pin; use std::task::Poll; use std::time::SystemTime; @@ -24,14 +27,24 @@ use url::Url; struct Args { /// The public (i.e. internet-reachable) IPv4 address of the relay server. /// - /// Must route to the local interface we listen on. + /// Must route to the local IPv4 interface we listen on. #[arg(long, env)] - public_ip4_addr: Ipv4Addr, - /// The address of the local interface we should listen on. + public_ip4_addr: Option, + /// The address of the local IPv4 interface we should listen on. /// /// Must not be a wildcard-address. #[arg(long, env)] - listen_ip4_addr: Ipv4Addr, + listen_ip4_addr: Option, + /// The public (i.e. internet-reachable) IPv6 address of the relay server. + /// + /// Must route to the local IP6 interface we listen on. + #[arg(long, env)] + public_ip6_addr: Option, + /// The address of the local IP6 interface we should listen on. + /// + /// Must not be a wildcard-address. + #[arg(long, env)] + listen_ip6_addr: Option, /// The address of the local interface where we should serve the prometheus metrics. /// /// The metrics will be available at `http:///metrics`. @@ -74,13 +87,25 @@ async fn main() -> Result<()> { tracing_subscriber::fmt().with_env_filter(env_filter).init() } - let mut metric_registry = Registry::with_prefix("relay"); + let listen_addr = match (args.listen_ip4_addr, args.listen_ip6_addr) { + (Some(ip4), Some(ip6)) => IpStack::Dual { ip4, ip6 }, + (Some(ip4), None) => IpStack::Ip4(ip4), + (None, Some(ip6)) => IpStack::Ip6(ip6), + (None, None) => { + bail!("Must listen on at least one of IPv4 or IPv6") + } + }; + let public_addr = match (args.public_ip4_addr, args.public_ip6_addr, listen_addr) { + (Some(ip4), Some(ip6), IpStack::Dual { .. }) => IpStack::Dual { ip4, ip6 }, + (Some(ip4), None, IpStack::Ip4(_)) => IpStack::Ip4(ip4), + (None, Some(ip6), IpStack::Ip6(_)) => IpStack::Ip6(ip6), + _ => { + bail!("Must specify a public address for each listen address") + } + }; - let server = Server::new( - args.public_ip4_addr, - make_rng(args.rng_seed), - &mut metric_registry, - ); + let mut metric_registry = Registry::with_prefix("relay"); + let server = Server::new(public_addr, make_rng(args.rng_seed), &mut metric_registry); let channel = if let Some(token) = args.portal_token { let mut url = args.portal_ws_url.clone(); @@ -92,9 +117,16 @@ async fn main() -> Result<()> { } url.set_path("relay/websocket"); - url.query_pairs_mut() - .append_pair("token", &token) - .append_pair("ipv4", &args.listen_ip4_addr.to_string()); + url.query_pairs_mut().append_pair("token", &token); + + if let Some(listen_ip4_addr) = args.listen_ip4_addr { + url.query_pairs_mut() + .append_pair("ipv4", &listen_ip4_addr.to_string()); + } + if let Some(listen_ip6_addr) = args.listen_ip6_addr { + url.query_pairs_mut() + .append_pair("ipv6", &listen_ip6_addr.to_string()); + } let mut channel = PhoenixChannel::::connect( url, @@ -137,8 +169,7 @@ async fn main() -> Result<()> { None }; - let mut eventloop = - Eventloop::new(server, channel, args.listen_ip4_addr, &mut metric_registry).await?; + let mut eventloop = Eventloop::new(server, channel, listen_addr, &mut metric_registry)?; if let Some(metrics_addr) = args.metrics_addr { tokio::spawn(relay::metrics::serve(metrics_addr, metric_registry)); @@ -186,11 +217,12 @@ fn make_rng(seed: Option) -> StdRng { struct Eventloop { inbound_data_receiver: mpsc::Receiver<(Vec, SocketAddr)>, - outbound_data_sender: mpsc::Sender<(Vec, SocketAddr)>, - listen_ip4_address: Ipv4Addr, + outbound_ip4_data_sender: mpsc::Sender<(Vec, SocketAddr)>, + outbound_ip6_data_sender: mpsc::Sender<(Vec, SocketAddr)>, + listen_address: IpStack, server: Server, channel: Option>, - allocations: HashMap, + allocations: HashMap<(AllocationId, AddressFamily), Allocation>, relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, relay_data_receiver: mpsc::Receiver<(Vec, SocketAddr, AllocationId)>, sleep: Sleep, @@ -200,27 +232,39 @@ impl Eventloop where R: Rng, { - async fn new( + fn new( server: Server, channel: Option>, - listen_ip4_address: Ipv4Addr, + listen_address: IpStack, _: &mut Registry, ) -> Result { let (relay_data_sender, relay_data_receiver) = mpsc::channel(1); let (inbound_data_sender, inbound_data_receiver) = mpsc::channel(10); - let (outbound_data_sender, outbound_data_receiver) = + let (outbound_ip4_data_sender, outbound_ip4_data_receiver) = + mpsc::channel::<(Vec, SocketAddr)>(10); + let (outbound_ip6_data_sender, outbound_ip6_data_receiver) = mpsc::channel::<(Vec, SocketAddr)>(10); - tokio::spawn(main_udp_socket_task( - listen_ip4_address, - inbound_data_sender, - outbound_data_receiver, - )); + if let Some(ip4) = listen_address.as_v4() { + tokio::spawn(main_udp_socket_task( + (*ip4).into(), + inbound_data_sender.clone(), + outbound_ip4_data_receiver, + )); + } + if let Some(ip6) = listen_address.as_v6() { + tokio::spawn(main_udp_socket_task( + (*ip6).into(), + inbound_data_sender, + outbound_ip6_data_receiver, + )); + } Ok(Self { inbound_data_receiver, - outbound_data_sender, - listen_ip4_address, + outbound_ip4_data_sender, + outbound_ip6_data_sender, + listen_address, server, channel, allocations: Default::default(), @@ -236,7 +280,12 @@ where if let Some(next_command) = self.server.next_command() { match next_command { Command::SendMessage { payload, recipient } => { - if let Err(e) = self.outbound_data_sender.try_send((payload, recipient)) { + let sender = match recipient.family() { + AddressFamily::V4 => &mut self.outbound_ip4_data_sender, + AddressFamily::V6 => &mut self.outbound_ip6_data_sender, + }; + + if let Err(e) = sender.try_send((payload, recipient)) { if e.is_disconnected() { return Poll::Ready(Err(anyhow!( "Channel to primary UDP socket task has been closed" @@ -248,19 +297,19 @@ where } } } - Command::AllocateAddresses { id, port } => { + Command::CreateAllocation { id, family, port } => { + let listen_addr = match family { + AddressFamily::V4 => (*self.listen_address.as_v4().expect("to have listen address for IP4 if we are creating an IP4 allocation")).into(), + AddressFamily::V6 => (*self.listen_address.as_v6().expect("to have listen address for IP6 if we are creating an IP6 allocation")).into(), + }; + self.allocations.insert( - id, - Allocation::new( - self.relay_data_sender.clone(), - id, - self.listen_ip4_address, - port, - ), + (id, family), + Allocation::new(self.relay_data_sender.clone(), id, listen_addr, port), ); } - Command::FreeAddresses { id } => { - if self.allocations.remove(&id).is_none() { + Command::FreeAllocation { id, family } => { + if self.allocations.remove(&(id, family)).is_none() { tracing::debug!("Unknown allocation {id}"); continue; }; @@ -271,10 +320,10 @@ where Pin::new(&mut self.sleep).reset(deadline); } Command::ForwardData { id, data, receiver } => { - let mut allocation = match self.allocations.entry(id) { + let mut allocation = match self.allocations.entry((id, receiver.family())) { Entry::Occupied(entry) => entry, Entry::Vacant(_) => { - tracing::debug!(allocation = %id, "Unknown allocation"); + tracing::debug!(allocation = %id, family = %receiver.family(), "Unknown allocation"); continue; } }; @@ -362,11 +411,11 @@ where } async fn main_udp_socket_task( - listen_ip4_address: Ipv4Addr, + listen_address: IpAddr, mut inbound_data_sender: mpsc::Sender<(Vec, SocketAddr)>, mut outbound_data_receiver: mpsc::Receiver<(Vec, SocketAddr)>, ) -> Result { - let mut socket = UdpSocket::bind((listen_ip4_address, 3478)).await?; + let mut socket = UdpSocket::bind((listen_address, 3478))?; loop { tokio::select! { diff --git a/rust/relay/src/net_ext.rs b/rust/relay/src/net_ext.rs new file mode 100644 index 000000000..8807a2664 --- /dev/null +++ b/rust/relay/src/net_ext.rs @@ -0,0 +1,28 @@ +use crate::rfc8656::AddressFamily; +use std::net::{IpAddr, SocketAddr}; + +pub trait IpAddrExt { + fn family(&self) -> AddressFamily; +} + +impl IpAddrExt for IpAddr { + fn family(&self) -> AddressFamily { + match self { + IpAddr::V4(_) => AddressFamily::V4, + IpAddr::V6(_) => AddressFamily::V6, + } + } +} + +pub trait SocketAddrExt { + fn family(&self) -> AddressFamily; +} + +impl SocketAddrExt for SocketAddr { + fn family(&self) -> AddressFamily { + match self { + SocketAddr::V4(_) => AddressFamily::V4, + SocketAddr::V6(_) => AddressFamily::V6, + } + } +} diff --git a/rust/relay/src/rfc8656.rs b/rust/relay/src/rfc8656.rs index 293a81213..aa613be55 100644 --- a/rust/relay/src/rfc8656.rs +++ b/rust/relay/src/rfc8656.rs @@ -2,7 +2,72 @@ // // TODO: Upstream this to `stun-codec`. +use bytecodec::fixnum::{U32beDecoder, U32beEncoder}; +use bytecodec::{ByteCount, Decode, Encode, Eos, ErrorKind, Result, SizedEncode, TryTaggedDecode}; +use std::fmt; use stun_codec::rfc5389::attributes::ErrorCode; +use stun_codec::{Attribute, AttributeType}; +use trackable::{track, track_panic}; + +macro_rules! impl_decode { + ($decoder:ty, $item:ident, $and_then:expr) => { + impl Decode for $decoder { + type Item = $item; + + fn decode(&mut self, buf: &[u8], eos: Eos) -> Result { + track!(self.0.decode(buf, eos)) + } + + fn finish_decoding(&mut self) -> Result { + track!(self.0.finish_decoding()).and_then($and_then) + } + + fn requiring_bytes(&self) -> ByteCount { + self.0.requiring_bytes() + } + + fn is_idle(&self) -> bool { + self.0.is_idle() + } + } + impl TryTaggedDecode for $decoder { + type Tag = AttributeType; + + fn try_start_decoding(&mut self, attr_type: Self::Tag) -> Result { + Ok(attr_type.as_u16() == $item::CODEPOINT) + } + } + }; +} + +macro_rules! impl_encode { + ($encoder:ty, $item:ty, $map_from:expr) => { + impl Encode for $encoder { + type Item = $item; + + fn encode(&mut self, buf: &mut [u8], eos: Eos) -> Result { + track!(self.0.encode(buf, eos)) + } + + fn start_encoding(&mut self, item: Self::Item) -> Result<()> { + track!(self.0.start_encoding($map_from(item).into())) + } + + fn requiring_bytes(&self) -> ByteCount { + self.0.requiring_bytes() + } + + fn is_idle(&self) -> bool { + self.0.is_idle() + } + } + impl SizedEncode for $encoder { + fn exact_requiring_bytes(&self) -> u64 { + self.0.exact_requiring_bytes() + } + } + }; +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct PeerAddressFamilyMismatch; @@ -19,3 +84,237 @@ impl From for ErrorCode { .expect("never fails") } } + +/// The family of an IP address, either IPv4 or IPv6. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum AddressFamily { + V4, + V6, +} + +impl fmt::Display for AddressFamily { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AddressFamily::V4 => write!(f, "IPv4"), + AddressFamily::V6 => write!(f, "IPv6"), + } + } +} + +const FAMILY_IPV4: u8 = 1; +const FAMILY_IPV6: u8 = 2; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RequestedAddressFamily(AddressFamily); +impl RequestedAddressFamily { + /// The codepoint of the type of the attribute. + pub const CODEPOINT: u16 = 0x0017; + + /// Makes a new `RequestedAddressFamily` instance. + pub fn new(fam: AddressFamily) -> Self { + RequestedAddressFamily(fam) + } + + /// Returns the requested address family. + pub fn address_family(&self) -> AddressFamily { + self.0 + } +} +impl Attribute for RequestedAddressFamily { + type Decoder = RequestedAddressFamilyDecoder; + type Encoder = RequestedAddressFamilyEncoder; + + fn get_type(&self) -> AttributeType { + AttributeType::new(Self::CODEPOINT) + } +} + +/// [`RequestedAddressFamily`] decoder. +#[derive(Debug, Default)] +pub struct RequestedAddressFamilyDecoder(AddressFamilyDecoder); +impl RequestedAddressFamilyDecoder { + /// Makes a new `RequestedAddressFamilyDecoder` instance. + pub fn new() -> Self { + Self::default() + } +} +impl_decode!( + RequestedAddressFamilyDecoder, + RequestedAddressFamily, + |item| Ok(RequestedAddressFamily(item)) +); + +/// [`RequestedAddressFamily`] encoder. +#[derive(Debug, Default)] +pub struct RequestedAddressFamilyEncoder(AddressFamilyEncoder); +impl RequestedAddressFamilyEncoder { + /// Makes a new `RequestedAddressFamilyEncoder` instance. + pub fn new() -> Self { + Self::default() + } +} +impl_encode!( + RequestedAddressFamilyEncoder, + RequestedAddressFamily, + |item: Self::Item| { item.0 } +); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AdditionalAddressFamily(AddressFamily); +impl AdditionalAddressFamily { + /// The codepoint of the type of the attribute. + pub const CODEPOINT: u16 = 0x8000; + + /// Makes a new `AdditionalAddressFamily` instance. + pub fn new(fam: AddressFamily) -> Self { + AdditionalAddressFamily(fam) + } + + /// Returns the requested address family. + pub fn address_family(&self) -> AddressFamily { + self.0 + } +} +impl Attribute for AdditionalAddressFamily { + type Decoder = AdditionalAddressFamilyDecoder; + type Encoder = AdditionalAddressFamilyEncoder; + + fn get_type(&self) -> AttributeType { + AttributeType::new(Self::CODEPOINT) + } +} + +/// [`AdditionalAddressFamily`] decoder. +#[derive(Debug, Default)] +pub struct AdditionalAddressFamilyDecoder(AddressFamilyDecoder); + +impl_decode!( + AdditionalAddressFamilyDecoder, + AdditionalAddressFamily, + |item| Ok(AdditionalAddressFamily(item)) +); + +/// [`AdditionalAddressFamily`] encoder. +#[derive(Debug, Default)] +pub struct AdditionalAddressFamilyEncoder(AddressFamilyEncoder); +impl_encode!( + AdditionalAddressFamilyEncoder, + AdditionalAddressFamily, + |item: Self::Item| { item.0 } +); + +/// [`RequestedAddressFamily`] decoder. +#[derive(Debug, Default)] +pub struct AddressFamilyDecoder { + family: U32beDecoder, +} + +impl Decode for AddressFamilyDecoder { + type Item = AddressFamily; + + fn decode(&mut self, buf: &[u8], eos: Eos) -> Result { + self.family.decode(buf, eos) + } + + fn finish_decoding(&mut self) -> Result { + let [fam, _, _, _] = self.family.finish_decoding()?.to_be_bytes(); + + match fam { + FAMILY_IPV4 => Ok(AddressFamily::V4), + FAMILY_IPV6 => Ok(AddressFamily::V6), + family => track_panic!( + ErrorKind::InvalidInput, + "Unknown address family: {}", + family + ), + } + } + + fn requiring_bytes(&self) -> ByteCount { + self.family.requiring_bytes() + } + + fn is_idle(&self) -> bool { + self.family.is_idle() + } +} + +/// [`RequestedAddressFamily`] decoder. +#[derive(Debug, Default)] +pub struct AddressFamilyEncoder { + family: U32beEncoder, +} + +impl Encode for AddressFamilyEncoder { + type Item = AddressFamily; + + fn encode(&mut self, buf: &mut [u8], eos: Eos) -> Result { + self.family.encode(buf, eos) + } + + fn start_encoding(&mut self, item: Self::Item) -> Result<()> { + let fam_byte = match item { + AddressFamily::V4 => FAMILY_IPV4, + AddressFamily::V6 => FAMILY_IPV6, + }; + + let bytes = [fam_byte, 0, 0, 0]; + + self.family.start_encoding(u32::from_be_bytes(bytes)) + } + + fn requiring_bytes(&self) -> ByteCount { + ByteCount::Finite(self.exact_requiring_bytes()) + } +} + +impl SizedEncode for AddressFamilyEncoder { + fn exact_requiring_bytes(&self) -> u64 { + self.family.exact_requiring_bytes() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct AddressFamilyNotSupported; + +impl AddressFamilyNotSupported { + /// The codepoint of the error. + pub const CODEPOINT: u16 = 440; +} +impl From for ErrorCode { + fn from(_: AddressFamilyNotSupported) -> Self { + ErrorCode::new( + AddressFamilyNotSupported::CODEPOINT, + "Address Family not Supported".to_string(), + ) + .expect("never fails") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytecodec::{DecodeExt, EncodeExt}; + + #[test] + fn address_family_encoder_works() { + let mut encoder = AddressFamilyEncoder::default(); + + let bytes = encoder.encode_into_bytes(AddressFamily::V4).unwrap(); + assert_eq!(bytes, [1, 0, 0, 0]); + + let bytes = encoder.encode_into_bytes(AddressFamily::V6).unwrap(); + assert_eq!(bytes, [2, 0, 0, 0]); + } + + #[test] + fn address_family_decoder_works() { + let mut decoder = AddressFamilyDecoder::default(); + + let fam = decoder.decode_from_bytes(&[1, 0, 0, 0]).unwrap(); + assert_eq!(fam, AddressFamily::V4); + + let fam = decoder.decode_from_bytes(&[2, 0, 0, 0]).unwrap(); + assert_eq!(fam, AddressFamily::V6); + } +} diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index 05c3dd4e8..43f8ec0f3 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -7,9 +7,13 @@ pub use crate::server::client_message::{ }; use crate::auth::{MessageIntegrityExt, Nonces, FIREZONE}; -use crate::rfc8656::PeerAddressFamilyMismatch; +use crate::net_ext::IpAddrExt; +use crate::rfc8656::{ + AdditionalAddressFamily, AddressFamily, AddressFamilyNotSupported, PeerAddressFamilyMismatch, + RequestedAddressFamily, +}; use crate::stun_codec_ext::{MessageClassExt, MethodExt}; -use crate::TimeEvents; +use crate::{IpStack, TimeEvents}; use anyhow::Result; use bytecodec::EncodeExt; use core::fmt; @@ -20,7 +24,7 @@ use prometheus_client::registry::Registry; use rand::Rng; use std::collections::{HashMap, VecDeque}; use std::hash::Hash; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::net::{IpAddr, SocketAddr}; use std::time::{Duration, SystemTime}; use stun_codec::rfc5389::attributes::{ ErrorCode, MessageIntegrity, Nonce, Realm, Username, XorMappedAddress, @@ -47,7 +51,7 @@ pub struct Server { decoder: client_message::Decoder, encoder: MessageEncoder, - public_ip4_address: Ipv4Addr, + public_address: IpStack, /// All client allocations, indexed by client's socket address. allocations: HashMap, @@ -82,12 +86,21 @@ pub enum Command { payload: Vec, recipient: SocketAddr, }, - /// Listen for traffic on the provided IP addresses. + /// Listen for traffic on the provided port [AddressFamily]. /// /// Any incoming data should be handed to the [`Server`] via [`Server::handle_relay_input`]. - AllocateAddresses { id: AllocationId, port: u16 }, - /// Free the addresses associated with the given [`AllocationId`]. - FreeAddresses { id: AllocationId }, + /// A single allocation can reference one of either [AddressFamily]s or both. + /// Only the combination of [AllocationId] and [AddressFamily] is unique. + CreateAllocation { + id: AllocationId, + family: AddressFamily, + port: u16, + }, + /// Free the allocation associated with the given [`AllocationId`] and [AddressFamily] + FreeAllocation { + id: AllocationId, + family: AddressFamily, + }, ForwardData { id: AllocationId, @@ -135,7 +148,7 @@ impl Server where R: Rng, { - pub fn new(public_ip4_address: Ipv4Addr, mut rng: R, registry: &mut Registry) -> Self { + pub fn new(public_address: impl Into, mut rng: R, registry: &mut Registry) -> Self { // TODO: Validate that local IP isn't multicast / loopback etc. let allocations_gauge = Gauge::default(); @@ -162,7 +175,7 @@ where Self { decoder: Default::default(), encoder: Default::default(), - public_ip4_address, + public_address: public_address.into(), allocations: Default::default(), clients_by_allocation: Default::default(), allocations_by_port: Default::default(), @@ -422,10 +435,22 @@ where return Err(error_response(BadRequest, &request)); } + let (first_relay_address, maybe_second_relay_addr) = derive_relay_addresses( + self.public_address, + request.requested_address_family(), + request.additional_address_family(), + ) + .map_err(|e| error_response(e, &request))?; + // TODO: Do we need to handle DONT-FRAGMENT? // TODO: Do we need to handle EVEN/ODD-PORT? - let allocation = self.create_new_allocation(now, &effective_lifetime); + let allocation = self.create_new_allocation( + now, + &effective_lifetime, + first_relay_address, + maybe_second_relay_addr, + ); let mut message = Message::new( MessageClass::SuccessResponse, @@ -433,9 +458,16 @@ where request.transaction_id(), ); - let ip4_relay_address = self.public_relay_address_for_port(allocation.port); + let port = allocation.port; + + message + .add_attribute(XorRelayAddress::new(SocketAddr::new(first_relay_address, port)).into()); + if let Some(second_relay_address) = maybe_second_relay_addr { + message.add_attribute( + XorRelayAddress::new(SocketAddr::new(second_relay_address, port)).into(), + ); + } - message.add_attribute(XorRelayAddress::new(ip4_relay_address.into()).into()); message.add_attribute(XorMappedAddress::new(sender).into()); message.add_attribute(effective_lifetime.clone().into()); @@ -446,17 +478,34 @@ where self.pending_commands.push_back(Command::Wake { deadline: wake_deadline, }); - self.pending_commands.push_back(Command::AllocateAddresses { + self.pending_commands.push_back(Command::CreateAllocation { id: allocation.id, - port: allocation.port, + family: first_relay_address.family(), + port, }); + if let Some(second_relay_addr) = maybe_second_relay_addr { + self.pending_commands.push_back(Command::CreateAllocation { + id: allocation.id, + family: second_relay_addr.family(), + port, + }); + } self.send_message(message, sender); - tracing::info!( - target: "relay", - ip4_relay_address = field::display(ip4_relay_address), - "Created new allocation", - ); + if let Some(second_relay_addr) = maybe_second_relay_addr { + tracing::info!( + target: "relay", + first_relay_address = field::display(first_relay_address), + second_relay_address = field::display(second_relay_addr), + "Created new allocation", + ) + } else { + tracing::info!( + target: "relay", + first_relay_address = field::display(first_relay_address), + "Created new allocation", + ) + } self.clients_by_allocation.insert(allocation.id, sender); self.allocations.insert(sender, allocation); @@ -693,7 +742,13 @@ where Ok(()) } - fn create_new_allocation(&mut self, now: SystemTime, lifetime: &Lifetime) -> Allocation { + fn create_new_allocation( + &mut self, + now: SystemTime, + lifetime: &Lifetime, + first_relay_addr: IpAddr, + second_relay_addr: Option, + ) -> Allocation { // First, find an unused port. assert!( @@ -718,6 +773,8 @@ where id, port, expires_at: now + lifetime.lifetime(), + first_relay_addr, + second_relay_addr, } } @@ -783,10 +840,6 @@ where .inc(); } - fn public_relay_address_for_port(&self, port: u16) -> SocketAddrV4 { - SocketAddrV4::new(self.public_ip4_address, port) - } - fn get_allocation(&self, id: &AllocationId) -> Option<&Allocation> { self.clients_by_allocation .get(id) @@ -805,9 +858,18 @@ where let port = allocation.port; self.allocations_by_port.remove(&port); + self.allocations_gauge.dec(); - self.pending_commands - .push_back(Command::FreeAddresses { id }); + self.pending_commands.push_back(Command::FreeAllocation { + id, + family: allocation.first_relay_addr.family(), + }); + if let Some(second_relay_addr) = allocation.second_relay_addr { + self.pending_commands.push_back(Command::FreeAllocation { + id, + family: second_relay_addr.family(), + }) + } tracing::info!(target: "relay", %port, "Deleted allocation"); } @@ -851,6 +913,9 @@ struct Allocation { /// Data arriving on this port will be forwarded to the client iff there is an active data channel. port: u16, expires_at: SystemTime, + + first_relay_addr: IpAddr, + second_relay_addr: Option, } struct Channel { @@ -919,6 +984,43 @@ fn error_response( message } +/// Derive the relay address for the client based on the request and the supported IP stack of the relay server. +/// +/// By default, a client gets an IPv4 address. +/// They can request an _additional_ IPv6 address or only an IPv6 address. +/// This is handled with two different STUN attributes: [AdditionalAddressFamily] and [RequestedAddressFamily]. +/// +/// The specification mandates certain checks for how these attributes can be used. +/// In a nutshell, the requirements constrain the use such that there is only one way of doing things. +/// For example, it is disallowed to use [RequestedAddressFamily] for IPv6 and requested and an IPv4 address via [AdditionalAddressFamily]. +/// If this is desired, clients should simply use [AdditionalAddressFamily] for IPv6. +fn derive_relay_addresses( + public_address: IpStack, + requested_addr_family: Option<&RequestedAddressFamily>, + additional_addr_family: Option<&AdditionalAddressFamily>, +) -> Result<(IpAddr, Option), ErrorCode> { + match ( + public_address, + requested_addr_family.map(|r| r.address_family()), + additional_addr_family.map(|a| a.address_family()), + ) { + ( + IpStack::Ip4(addr) | IpStack::Dual { ip4: addr, .. }, + None | Some(AddressFamily::V4), + None, + ) => Ok((addr.into(), None)), + (IpStack::Ip6(addr) | IpStack::Dual { ip6: addr, .. }, Some(AddressFamily::V6), None) => { + Ok((addr.into(), None)) + } + (IpStack::Dual { ip4, ip6 }, None, Some(AddressFamily::V6)) => { + Ok((ip4.into(), Some(ip6.into()))) + } + (_, Some(_), Some(_)) => Err(BadRequest.into()), + (_, _, Some(AddressFamily::V4)) => Err(BadRequest.into()), + _ => Err(AddressFamilyNotSupported.into()), + } +} + /// Private helper trait to make [`error_response`] more ergonomic to use. trait StunRequest { fn transaction_id(&self) -> TransactionId; @@ -990,7 +1092,9 @@ stun_codec::define_attribute_enums!( XorPeerAddress, Nonce, Realm, - Username + Username, + RequestedAddressFamily, + AdditionalAddressFamily ] ); @@ -1014,3 +1118,65 @@ enum MessageType { CreatePermission, Refresh, } + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; + + // Tests for requirements listed in https://www.rfc-editor.org/rfc/rfc8656#name-receiving-an-allocate-reque. + + // 6. The server checks if the request contains both REQUESTED-ADDRESS-FAMILY and ADDITIONAL-ADDRESS-FAMILY attributes. If yes, then the server rejects the request with a 400 (Bad Request) error. + #[test] + fn requested_and_additional_is_bad_request() { + let error_code = derive_relay_addresses( + IpStack::Ip4(Ipv4Addr::LOCALHOST), + Some(&RequestedAddressFamily::new(AddressFamily::V4)), + Some(&AdditionalAddressFamily::new(AddressFamily::V6)), + ) + .unwrap_err(); + + assert_eq!(error_code.code(), BadRequest::CODEPOINT) + } + + // 7. If the server does not support the address family requested by the client in REQUESTED-ADDRESS-FAMILY, or if the allocation of the requested address family is disabled by local policy, it MUST generate an Allocate error response, and it MUST include an ERROR-CODE attribute with the 440 (Address Family not Supported) response code. + // If the REQUESTED-ADDRESS-FAMILY attribute is absent and the server does not support the IPv4 address family, the server MUST include an ERROR-CODE attribute with the 440 (Address Family not Supported) response code. + #[test] + fn requested_address_family_not_available_is_not_supported() { + let error_code = derive_relay_addresses( + IpStack::Ip4(Ipv4Addr::LOCALHOST), + Some(&RequestedAddressFamily::new(AddressFamily::V6)), + None, + ) + .unwrap_err(); + + assert_eq!(error_code.code(), AddressFamilyNotSupported::CODEPOINT); + + let error_code = derive_relay_addresses( + IpStack::Ip6(Ipv6Addr::LOCALHOST), + Some(&RequestedAddressFamily::new(AddressFamily::V4)), + None, + ) + .unwrap_err(); + + assert_eq!(error_code.code(), AddressFamilyNotSupported::CODEPOINT); + + let error_code = + derive_relay_addresses(IpStack::Ip6(Ipv6Addr::LOCALHOST), None, None).unwrap_err(); + + assert_eq!(error_code.code(), AddressFamilyNotSupported::CODEPOINT) + } + + //9. The server checks if the request contains an ADDITIONAL-ADDRESS-FAMILY attribute. If yes, and the attribute value is 0x01 (IPv4 address family), then the server rejects the request with a 400 (Bad Request) error. + #[test] + fn additional_address_family_ip4_is_bad_request() { + let error_code = derive_relay_addresses( + IpStack::Ip4(Ipv4Addr::LOCALHOST), + None, + Some(&AdditionalAddressFamily::new(AddressFamily::V4)), + ) + .unwrap_err(); + + assert_eq!(error_code.code(), BadRequest::CODEPOINT) + } +} diff --git a/rust/relay/src/server/client_message.rs b/rust/relay/src/server/client_message.rs index b7c1e309c..acb04ebf6 100644 --- a/rust/relay/src/server/client_message.rs +++ b/rust/relay/src/server/client_message.rs @@ -1,4 +1,5 @@ use crate::auth::{generate_password, split_username, systemtime_from_unix, FIREZONE}; +use crate::rfc8656::{AdditionalAddressFamily, AddressFamily, RequestedAddressFamily}; use crate::server::channel_data::ChannelData; use crate::server::UDP_TRANSPORT; use crate::Attribute; @@ -121,37 +122,26 @@ pub struct Allocate { lifetime: Option, username: Option, nonce: Option, + requested_address_family: Option, + additional_address_family: Option, } impl Allocate { - pub fn new_authenticated_udp( + pub fn new_authenticated_udp_implicit_ip4( transaction_id: TransactionId, lifetime: Option, username: Username, relay_secret: &str, nonce: Uuid, ) -> Self { - let requested_transport = RequestedTransport::new(UDP_TRANSPORT); - let nonce = Nonce::new(nonce.as_hyphenated().to_string()).expect("len(uuid) < 128"); - - let mut message = - Message::::new(MessageClass::Request, ALLOCATE, transaction_id); - message.add_attribute(requested_transport.clone().into()); - message.add_attribute(username.clone().into()); - message.add_attribute(nonce.clone().into()); - - if let Some(lifetime) = &lifetime { - message.add_attribute(lifetime.clone().into()); - } - - let (expiry, salt) = split_username(username.name()).expect("a valid username"); - let expiry_systemtime = systemtime_from_unix(expiry); - - let password = generate_password(relay_secret, expiry_systemtime, salt); - - let message_integrity = - MessageIntegrity::new_long_term_credential(&message, &username, &FIREZONE, &password) - .unwrap(); + let (requested_transport, nonce, message_integrity) = Self::make_attributes( + transaction_id, + &lifetime, + &username, + relay_secret, + nonce, + None, + ); Self { transaction_id, @@ -160,6 +150,38 @@ impl Allocate { lifetime, username: Some(username), nonce: Some(nonce), + requested_address_family: None, // IPv4 is the default. + additional_address_family: None, + } + } + + pub fn new_authenticated_udp_ip6( + transaction_id: TransactionId, + lifetime: Option, + username: Username, + relay_secret: &str, + nonce: Uuid, + ) -> Self { + let requested_address_family = RequestedAddressFamily::new(AddressFamily::V6); + + let (requested_transport, nonce, message_integrity) = Self::make_attributes( + transaction_id, + &lifetime, + &username, + relay_secret, + nonce, + Some(requested_address_family.clone()), + ); + + Self { + transaction_id, + message_integrity: Some(message_integrity), + requested_transport, + lifetime, + username: Some(username), + nonce: Some(nonce), + requested_address_family: Some(requested_address_family), + additional_address_family: None, } } @@ -184,9 +206,47 @@ impl Allocate { lifetime, username: None, nonce: None, + requested_address_family: None, + additional_address_family: None, } } + fn make_attributes( + transaction_id: TransactionId, + lifetime: &Option, + username: &Username, + relay_secret: &str, + nonce: Uuid, + requested_address_family: Option, + ) -> (RequestedTransport, Nonce, MessageIntegrity) { + let requested_transport = RequestedTransport::new(UDP_TRANSPORT); + let nonce = Nonce::new(nonce.as_hyphenated().to_string()).expect("len(uuid) < 128"); + + let mut message = + Message::::new(MessageClass::Request, ALLOCATE, transaction_id); + message.add_attribute(requested_transport.clone().into()); + message.add_attribute(username.clone().into()); + message.add_attribute(nonce.clone().into()); + + if let Some(requested_address_family) = requested_address_family { + message.add_attribute(requested_address_family.into()); + } + + if let Some(lifetime) = &lifetime { + message.add_attribute(lifetime.clone().into()); + } + + let (expiry, salt) = split_username(username.name()).expect("a valid username"); + let expiry_systemtime = systemtime_from_unix(expiry); + + let password = generate_password(relay_secret, expiry_systemtime, salt); + + let message_integrity = + MessageIntegrity::new_long_term_credential(&message, username, &FIREZONE, &password) + .unwrap(); + (requested_transport, nonce, message_integrity) + } + pub fn parse(message: &Message) -> Result> { let transaction_id = message.transaction_id(); let message_integrity = message.get_attribute::().cloned(); @@ -197,6 +257,8 @@ impl Allocate { .clone(); let lifetime = message.get_attribute::().cloned(); let username = message.get_attribute::().cloned(); + let requested_address_family = message.get_attribute::().cloned(); + let additional_address_family = message.get_attribute::().cloned(); Ok(Allocate { transaction_id, @@ -205,6 +267,8 @@ impl Allocate { lifetime, username, nonce, + requested_address_family, + additional_address_family, }) } @@ -231,6 +295,14 @@ impl Allocate { pub fn nonce(&self) -> Option<&Nonce> { self.nonce.as_ref() } + + pub fn requested_address_family(&self) -> Option<&RequestedAddressFamily> { + self.requested_address_family.as_ref() + } + + pub fn additional_address_family(&self) -> Option<&AdditionalAddressFamily> { + self.additional_address_family.as_ref() + } } pub struct Refresh { diff --git a/rust/relay/src/udp_socket.rs b/rust/relay/src/udp_socket.rs index 8359d781f..92cd74070 100644 --- a/rust/relay/src/udp_socket.rs +++ b/rust/relay/src/udp_socket.rs @@ -1,4 +1,5 @@ -use anyhow::Result; +use crate::{AddressFamily, SocketAddrExt}; +use anyhow::{Context as _, Result}; use std::net::SocketAddr; use std::task::{ready, Context, Poll}; use tokio::io::ReadBuf; @@ -12,9 +13,13 @@ pub struct UdpSocket { } impl UdpSocket { - pub async fn bind(addr: impl Into) -> Result { + pub fn bind(addr: impl Into) -> Result { + let addr = addr.into(); + let std_socket = make_std_socket(addr) + .with_context(|| format!("Failed to bind UDP socket to {addr}"))?; + Ok(Self { - inner: tokio::net::UdpSocket::bind(addr.into()).await?, + inner: tokio::net::UdpSocket::from_std(std_socket)?, recv_buf: [0u8; MAX_UDP_SIZE], }) } @@ -52,3 +57,25 @@ impl UdpSocket { Poll::Ready(Ok(())) } } + +/// Creates an [std::net::UdpSocket] via the [socket2] library that is configured for our needs. +/// +/// Most importantly, this sets the `IPV6_V6ONLY` flag to ensure we disallow IP4-mapped IPv6 addresses and can bind to IP4 and IP6 addresses on the same port. +fn make_std_socket(socket_addr: SocketAddr) -> Result { + use socket2::*; + + let domain = match socket_addr.family() { + AddressFamily::V4 => Domain::IPV4, + AddressFamily::V6 => Domain::IPV6, + }; + let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + if socket_addr.is_ipv6() { + socket.set_only_v6(true)?; + } + + socket.set_nonblocking(true)?; + socket.bind(&socket_addr.into())?; + + Ok(socket.into()) +} diff --git a/rust/relay/tests/regression.rs b/rust/relay/tests/regression.rs index af14298c4..6436a7fcd 100644 --- a/rust/relay/tests/regression.rs +++ b/rust/relay/tests/regression.rs @@ -2,12 +2,12 @@ use bytecodec::{DecodeExt, EncodeExt}; use prometheus_client::registry::Registry; use rand::rngs::mock::StepRng; use relay::{ - Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command, - Refresh, Server, + AddressFamily, Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, + ClientMessage, Command, IpStack, Refresh, Server, }; use std::collections::HashMap; use std::iter; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; use std::time::{Duration, SystemTime}; use stun_codec::rfc5389::attributes::{ErrorCode, Nonce, Realm, Username, XorMappedAddress}; use stun_codec::rfc5389::errors::Unauthorized; @@ -55,7 +55,7 @@ fn deallocate_once_time_expired( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( transaction_id, Some(lifetime.clone()), valid_username(now, &username_salt), @@ -66,7 +66,7 @@ fn deallocate_once_time_expired( ), [ Wake(now + lifetime.lifetime()), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response(transaction_id, public_relay_addr, 49152, source, &lifetime), @@ -76,7 +76,7 @@ fn deallocate_once_time_expired( server.assert_commands( forward_time_to(now + lifetime.lifetime() + Duration::from_secs(1)), - [FreeAllocation(49152)], + [FreeAllocation(49152, AddressFamily::V4)], ); } @@ -110,7 +110,7 @@ fn unauthenticated_allocate_triggers_authentication( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( transaction_id, Some(lifetime.clone()), valid_username(now, &username_salt), @@ -121,7 +121,7 @@ fn unauthenticated_allocate_triggers_authentication( ), [ Wake(now + lifetime.lifetime()), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response(transaction_id, public_relay_addr, 49152, source, &lifetime), @@ -149,7 +149,7 @@ fn when_refreshed_in_time_allocation_does_not_expire( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(allocate_lifetime.clone()), valid_username(now, &username_salt), @@ -160,7 +160,7 @@ fn when_refreshed_in_time_allocation_does_not_expire( ), [ Wake(first_wake), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response( @@ -225,7 +225,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(allocate_lifetime.clone()), valid_username(now, &username_salt), @@ -236,7 +236,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( ), [ Wake(first_wake), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response( @@ -266,7 +266,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( now, ), [ - FreeAllocation(49152), + FreeAllocation(49152, AddressFamily::V4), send_message( source, refresh_response( @@ -309,7 +309,7 @@ fn ping_pong_relay( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(lifetime.clone()), valid_username(now, &username_salt), @@ -320,7 +320,7 @@ fn ping_pong_relay( ), [ Wake(now + lifetime.lifetime()), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response( @@ -375,13 +375,57 @@ fn ping_pong_relay( ); } +#[proptest] +fn can_make_ipv6_allocation( + #[strategy(relay::proptest::transaction_id())] transaction_id: TransactionId, + #[strategy(relay::proptest::allocation_lifetime())] lifetime: Lifetime, + #[strategy(relay::proptest::username_salt())] username_salt: String, + source: SocketAddrV4, + public_relay_ip4_addr: Ipv4Addr, + public_relay_ip6_addr: Ipv6Addr, + #[strategy(relay::proptest::now())] now: SystemTime, + #[strategy(relay::proptest::nonce())] nonce: Uuid, +) { + let mut server = + TestServer::new((public_relay_ip4_addr, public_relay_ip6_addr)).with_nonce(nonce); + let secret = server.auth_secret(); + + server.assert_commands( + from_client( + source, + Allocate::new_authenticated_udp_ip6( + transaction_id, + Some(lifetime.clone()), + valid_username(now, &username_salt), + secret, + nonce, + ), + now, + ), + [ + Wake(now + lifetime.lifetime()), + CreateAllocation(49152, AddressFamily::V6), + send_message( + source, + allocate_response( + transaction_id, + public_relay_ip6_addr, + 49152, + source, + &lifetime, + ), + ), + ], + ); +} + struct TestServer { server: Server, id_to_port: HashMap, } impl TestServer { - fn new(relay_public_addr: Ipv4Addr) -> Self { + fn new(relay_public_addr: impl Into) -> Self { Self { server: Server::new( relay_public_addr, @@ -421,8 +465,8 @@ impl TestServer { let msg = match expected_output { Output::SendMessage((recipient, msg)) => format!("to send message {:?} to {recipient}", msg), Wake(time) => format!("to be woken at {time:?}"), - CreateAllocation(port) => format!("to create allocation on port {port}"), - FreeAllocation(port) => format!("to free allocation on port {port}"), + CreateAllocation(port, family) => format!("to create allocation on port {port} for address family {family}"), + FreeAllocation(port, family) => format!("to free allocation on port {port} for address family {family}"), Output::SendChannelData((peer, _)) => format!("to send channel data from {peer} to client"), Output::Forward((peer, _, _)) => format!("to forward data to peer {peer}") }; @@ -452,18 +496,27 @@ impl TestServer { assert_eq!(when, deadline); } ( - CreateAllocation(expected_port), - Command::AllocateAddresses { + CreateAllocation(expected_port, expected_family), + Command::CreateAllocation { id, + family: actual_family, port: actual_port, }, ) => { self.id_to_port.insert(actual_port, id); assert_eq!(expected_port, actual_port); + assert_eq!(expected_family, actual_family); } - (FreeAllocation(port), Command::FreeAddresses { id }) => { + ( + FreeAllocation(port, family), + Command::FreeAllocation { + id, + family: actual_family, + }, + ) => { let actual_id = self.id_to_port.remove(&port).expect("to have port in map"); assert_eq!(id, actual_id); + assert_eq!(family, actual_family); } (Wake(when), Command::SendMessage { payload, .. }) => { panic!( @@ -530,7 +583,7 @@ fn binding_response( fn allocate_response( transaction_id: TransactionId, - public_relay_addr: Ipv4Addr, + public_relay_addr: impl Into, port: u16, source: SocketAddrV4, lifetime: &Lifetime, @@ -538,7 +591,7 @@ fn allocate_response( let mut message = Message::::new(MessageClass::SuccessResponse, ALLOCATE, transaction_id); message.add_attribute( - XorRelayAddress::new(SocketAddrV4::new(public_relay_addr, port).into()).into(), + XorRelayAddress::new(SocketAddr::new(public_relay_addr.into(), port)).into(), ); message.add_attribute(XorMappedAddress::new(source.into()).into()); message.add_attribute(lifetime.clone().into()); @@ -610,8 +663,8 @@ enum Output<'a> { SendChannelData((SocketAddr, ChannelData<'a>)), Forward((SocketAddr, Vec, u16)), Wake(SystemTime), - CreateAllocation(u16), - FreeAllocation(u16), + CreateAllocation(u16, AddressFamily), + FreeAllocation(u16, AddressFamily), } fn send_message<'a>(source: impl Into, message: Message) -> Output<'a> {