diff --git a/rust/Cargo.lock b/rust/Cargo.lock index ef019af07..e48f26521 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1511,7 +1511,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio", "tower-service", "tracing", @@ -2548,12 +2548,14 @@ dependencies = [ "redis", "serde", "sha2", + "socket2 0.5.3", "stun_codec", "test-strategy", "tokio", "tracing", "tracing-stackdriver", "tracing-subscriber", + "trackable 1.3.0", "url", "uuid", "webrtc", @@ -2977,6 +2979,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "spin" version = "0.5.2" @@ -3294,7 +3306,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.4.9", "tokio-macros", "windows-sys 0.48.0", ] @@ -3917,7 +3929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f08dfd7a6e3987e255c4dbe710dde5d94d0f0574f8a21afa95d171376c143106" dependencies = [ "log", - "socket2", + "socket2 0.4.9", "thiserror", "tokio", "webrtc-util", diff --git a/rust/relay/Cargo.toml b/rust/relay/Cargo.toml index 9e029582e..8ff309a0c 100644 --- a/rust/relay/Cargo.toml +++ b/rust/relay/Cargo.toml @@ -28,6 +28,8 @@ uuid = { version = "1.4.1", features = ["v4"] } phoenix-channel = { path = "../phoenix-channel" } url = "2.4.0" serde = { version = "1.0.179", features = ["derive"] } +trackable = "1.3.0" +socket2 = "0.5.3" prometheus-client = "0.21.1" axum = { version = "0.6.18", default-features = false, features = ["http1", "tokio"] } diff --git a/rust/relay/src/allocation.rs b/rust/relay/src/allocation.rs index fadaf8094..bb527ac22 100644 --- a/rust/relay/src/allocation.rs +++ b/rust/relay/src/allocation.rs @@ -4,7 +4,7 @@ use anyhow::{bail, Result}; use futures::channel::mpsc; use futures::{SinkExt, StreamExt}; use std::convert::Infallible; -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, SocketAddr}; use tokio::task; /// The maximum amount of items that can be buffered in the channel to the allocation task. @@ -24,17 +24,17 @@ impl Allocation { pub fn new( relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, id: AllocationId, - listen_ip4_addr: Ipv4Addr, + listen_addr: IpAddr, port: u16, ) -> Self { let (client_to_peer_sender, client_to_peer_receiver) = mpsc::channel(MAX_BUFFERED_ITEMS); let task = tokio::spawn(async move { - let Err(e) = forward_incoming_relay_data(relay_data_sender, client_to_peer_receiver, id, listen_ip4_addr, port).await else { + let Err(e) = forward_incoming_relay_data(relay_data_sender, client_to_peer_receiver, id, listen_addr, port).await else { unreachable!() }; - tracing::warn!("Allocation task for {id} failed: {e}"); + tracing::warn!(allocation = %id, %listen_addr, "Allocation task failed: {e:#}"); // With the task stopping, the channel will be closed and any attempt to send data to it will fail. }); @@ -83,10 +83,10 @@ async fn forward_incoming_relay_data( mut relayed_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, mut client_to_peer_receiver: mpsc::Receiver<(Vec, SocketAddr)>, id: AllocationId, - listen_ip4_addr: Ipv4Addr, + listen_addr: IpAddr, port: u16, ) -> Result { - let mut socket = UdpSocket::bind((listen_ip4_addr, port)).await?; + let mut socket = UdpSocket::bind((listen_addr, port))?; loop { tokio::select! { diff --git a/rust/relay/src/lib.rs b/rust/relay/src/lib.rs index b6dfeab43..8450c0b93 100644 --- a/rust/relay/src/lib.rs +++ b/rust/relay/src/lib.rs @@ -1,5 +1,6 @@ mod allocation; mod auth; +mod net_ext; mod rfc8656; mod server; mod sleep; @@ -12,6 +13,8 @@ pub mod metrics; pub mod proptest; pub use allocation::Allocation; +pub use net_ext::{IpAddrExt, SocketAddrExt}; +pub use rfc8656::AddressFamily; pub use server::{ Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command, CreatePermission, Refresh, Server, @@ -20,3 +23,49 @@ pub use sleep::Sleep; pub use udp_socket::UdpSocket; pub(crate) use time_events::TimeEvents; + +use std::net::{Ipv4Addr, Ipv6Addr}; + +/// Describes the IP stack of a relay server. +#[derive(Debug, Copy, Clone)] +pub enum IpStack { + Ip4(Ipv4Addr), + Ip6(Ipv6Addr), + Dual { ip4: Ipv4Addr, ip6: Ipv6Addr }, +} + +impl IpStack { + pub fn as_v4(&self) -> Option<&Ipv4Addr> { + match self { + IpStack::Ip4(ip4) => Some(ip4), + IpStack::Ip6(_) => None, + IpStack::Dual { ip4, .. } => Some(ip4), + } + } + + pub fn as_v6(&self) -> Option<&Ipv6Addr> { + match self { + IpStack::Ip4(_) => None, + IpStack::Ip6(ip6) => Some(ip6), + IpStack::Dual { ip6, .. } => Some(ip6), + } + } +} + +impl From for IpStack { + fn from(value: Ipv4Addr) -> Self { + IpStack::Ip4(value) + } +} + +impl From for IpStack { + fn from(value: Ipv6Addr) -> Self { + IpStack::Ip6(value) + } +} + +impl From<(Ipv4Addr, Ipv6Addr)> for IpStack { + fn from((ip4, ip6): (Ipv4Addr, Ipv6Addr)) -> Self { + IpStack::Dual { ip4, ip6 } + } +} diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 58d8cb15e..e1ae2625e 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -6,11 +6,14 @@ use phoenix_channel::{Error, Event, PhoenixChannel}; use prometheus_client::registry::Registry; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use relay::{Allocation, AllocationId, Command, Server, Sleep, UdpSocket}; +use relay::{ + AddressFamily, Allocation, AllocationId, Command, IpStack, Server, Sleep, SocketAddrExt, + UdpSocket, +}; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::convert::Infallible; -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::pin::Pin; use std::task::Poll; use std::time::SystemTime; @@ -24,14 +27,24 @@ use url::Url; struct Args { /// The public (i.e. internet-reachable) IPv4 address of the relay server. /// - /// Must route to the local interface we listen on. + /// Must route to the local IPv4 interface we listen on. #[arg(long, env)] - public_ip4_addr: Ipv4Addr, - /// The address of the local interface we should listen on. + public_ip4_addr: Option, + /// The address of the local IPv4 interface we should listen on. /// /// Must not be a wildcard-address. #[arg(long, env)] - listen_ip4_addr: Ipv4Addr, + listen_ip4_addr: Option, + /// The public (i.e. internet-reachable) IPv6 address of the relay server. + /// + /// Must route to the local IP6 interface we listen on. + #[arg(long, env)] + public_ip6_addr: Option, + /// The address of the local IP6 interface we should listen on. + /// + /// Must not be a wildcard-address. + #[arg(long, env)] + listen_ip6_addr: Option, /// The address of the local interface where we should serve the prometheus metrics. /// /// The metrics will be available at `http:///metrics`. @@ -74,13 +87,25 @@ async fn main() -> Result<()> { tracing_subscriber::fmt().with_env_filter(env_filter).init() } - let mut metric_registry = Registry::with_prefix("relay"); + let listen_addr = match (args.listen_ip4_addr, args.listen_ip6_addr) { + (Some(ip4), Some(ip6)) => IpStack::Dual { ip4, ip6 }, + (Some(ip4), None) => IpStack::Ip4(ip4), + (None, Some(ip6)) => IpStack::Ip6(ip6), + (None, None) => { + bail!("Must listen on at least one of IPv4 or IPv6") + } + }; + let public_addr = match (args.public_ip4_addr, args.public_ip6_addr, listen_addr) { + (Some(ip4), Some(ip6), IpStack::Dual { .. }) => IpStack::Dual { ip4, ip6 }, + (Some(ip4), None, IpStack::Ip4(_)) => IpStack::Ip4(ip4), + (None, Some(ip6), IpStack::Ip6(_)) => IpStack::Ip6(ip6), + _ => { + bail!("Must specify a public address for each listen address") + } + }; - let server = Server::new( - args.public_ip4_addr, - make_rng(args.rng_seed), - &mut metric_registry, - ); + let mut metric_registry = Registry::with_prefix("relay"); + let server = Server::new(public_addr, make_rng(args.rng_seed), &mut metric_registry); let channel = if let Some(token) = args.portal_token { let mut url = args.portal_ws_url.clone(); @@ -92,9 +117,16 @@ async fn main() -> Result<()> { } url.set_path("relay/websocket"); - url.query_pairs_mut() - .append_pair("token", &token) - .append_pair("ipv4", &args.listen_ip4_addr.to_string()); + url.query_pairs_mut().append_pair("token", &token); + + if let Some(listen_ip4_addr) = args.listen_ip4_addr { + url.query_pairs_mut() + .append_pair("ipv4", &listen_ip4_addr.to_string()); + } + if let Some(listen_ip6_addr) = args.listen_ip6_addr { + url.query_pairs_mut() + .append_pair("ipv6", &listen_ip6_addr.to_string()); + } let mut channel = PhoenixChannel::::connect( url, @@ -137,8 +169,7 @@ async fn main() -> Result<()> { None }; - let mut eventloop = - Eventloop::new(server, channel, args.listen_ip4_addr, &mut metric_registry).await?; + let mut eventloop = Eventloop::new(server, channel, listen_addr, &mut metric_registry)?; if let Some(metrics_addr) = args.metrics_addr { tokio::spawn(relay::metrics::serve(metrics_addr, metric_registry)); @@ -186,11 +217,12 @@ fn make_rng(seed: Option) -> StdRng { struct Eventloop { inbound_data_receiver: mpsc::Receiver<(Vec, SocketAddr)>, - outbound_data_sender: mpsc::Sender<(Vec, SocketAddr)>, - listen_ip4_address: Ipv4Addr, + outbound_ip4_data_sender: mpsc::Sender<(Vec, SocketAddr)>, + outbound_ip6_data_sender: mpsc::Sender<(Vec, SocketAddr)>, + listen_address: IpStack, server: Server, channel: Option>, - allocations: HashMap, + allocations: HashMap<(AllocationId, AddressFamily), Allocation>, relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, relay_data_receiver: mpsc::Receiver<(Vec, SocketAddr, AllocationId)>, sleep: Sleep, @@ -200,27 +232,39 @@ impl Eventloop where R: Rng, { - async fn new( + fn new( server: Server, channel: Option>, - listen_ip4_address: Ipv4Addr, + listen_address: IpStack, _: &mut Registry, ) -> Result { let (relay_data_sender, relay_data_receiver) = mpsc::channel(1); let (inbound_data_sender, inbound_data_receiver) = mpsc::channel(10); - let (outbound_data_sender, outbound_data_receiver) = + let (outbound_ip4_data_sender, outbound_ip4_data_receiver) = + mpsc::channel::<(Vec, SocketAddr)>(10); + let (outbound_ip6_data_sender, outbound_ip6_data_receiver) = mpsc::channel::<(Vec, SocketAddr)>(10); - tokio::spawn(main_udp_socket_task( - listen_ip4_address, - inbound_data_sender, - outbound_data_receiver, - )); + if let Some(ip4) = listen_address.as_v4() { + tokio::spawn(main_udp_socket_task( + (*ip4).into(), + inbound_data_sender.clone(), + outbound_ip4_data_receiver, + )); + } + if let Some(ip6) = listen_address.as_v6() { + tokio::spawn(main_udp_socket_task( + (*ip6).into(), + inbound_data_sender, + outbound_ip6_data_receiver, + )); + } Ok(Self { inbound_data_receiver, - outbound_data_sender, - listen_ip4_address, + outbound_ip4_data_sender, + outbound_ip6_data_sender, + listen_address, server, channel, allocations: Default::default(), @@ -236,7 +280,12 @@ where if let Some(next_command) = self.server.next_command() { match next_command { Command::SendMessage { payload, recipient } => { - if let Err(e) = self.outbound_data_sender.try_send((payload, recipient)) { + let sender = match recipient.family() { + AddressFamily::V4 => &mut self.outbound_ip4_data_sender, + AddressFamily::V6 => &mut self.outbound_ip6_data_sender, + }; + + if let Err(e) = sender.try_send((payload, recipient)) { if e.is_disconnected() { return Poll::Ready(Err(anyhow!( "Channel to primary UDP socket task has been closed" @@ -248,19 +297,19 @@ where } } } - Command::AllocateAddresses { id, port } => { + Command::CreateAllocation { id, family, port } => { + let listen_addr = match family { + AddressFamily::V4 => (*self.listen_address.as_v4().expect("to have listen address for IP4 if we are creating an IP4 allocation")).into(), + AddressFamily::V6 => (*self.listen_address.as_v6().expect("to have listen address for IP6 if we are creating an IP6 allocation")).into(), + }; + self.allocations.insert( - id, - Allocation::new( - self.relay_data_sender.clone(), - id, - self.listen_ip4_address, - port, - ), + (id, family), + Allocation::new(self.relay_data_sender.clone(), id, listen_addr, port), ); } - Command::FreeAddresses { id } => { - if self.allocations.remove(&id).is_none() { + Command::FreeAllocation { id, family } => { + if self.allocations.remove(&(id, family)).is_none() { tracing::debug!("Unknown allocation {id}"); continue; }; @@ -271,10 +320,10 @@ where Pin::new(&mut self.sleep).reset(deadline); } Command::ForwardData { id, data, receiver } => { - let mut allocation = match self.allocations.entry(id) { + let mut allocation = match self.allocations.entry((id, receiver.family())) { Entry::Occupied(entry) => entry, Entry::Vacant(_) => { - tracing::debug!(allocation = %id, "Unknown allocation"); + tracing::debug!(allocation = %id, family = %receiver.family(), "Unknown allocation"); continue; } }; @@ -362,11 +411,11 @@ where } async fn main_udp_socket_task( - listen_ip4_address: Ipv4Addr, + listen_address: IpAddr, mut inbound_data_sender: mpsc::Sender<(Vec, SocketAddr)>, mut outbound_data_receiver: mpsc::Receiver<(Vec, SocketAddr)>, ) -> Result { - let mut socket = UdpSocket::bind((listen_ip4_address, 3478)).await?; + let mut socket = UdpSocket::bind((listen_address, 3478))?; loop { tokio::select! { diff --git a/rust/relay/src/net_ext.rs b/rust/relay/src/net_ext.rs new file mode 100644 index 000000000..8807a2664 --- /dev/null +++ b/rust/relay/src/net_ext.rs @@ -0,0 +1,28 @@ +use crate::rfc8656::AddressFamily; +use std::net::{IpAddr, SocketAddr}; + +pub trait IpAddrExt { + fn family(&self) -> AddressFamily; +} + +impl IpAddrExt for IpAddr { + fn family(&self) -> AddressFamily { + match self { + IpAddr::V4(_) => AddressFamily::V4, + IpAddr::V6(_) => AddressFamily::V6, + } + } +} + +pub trait SocketAddrExt { + fn family(&self) -> AddressFamily; +} + +impl SocketAddrExt for SocketAddr { + fn family(&self) -> AddressFamily { + match self { + SocketAddr::V4(_) => AddressFamily::V4, + SocketAddr::V6(_) => AddressFamily::V6, + } + } +} diff --git a/rust/relay/src/rfc8656.rs b/rust/relay/src/rfc8656.rs index 293a81213..aa613be55 100644 --- a/rust/relay/src/rfc8656.rs +++ b/rust/relay/src/rfc8656.rs @@ -2,7 +2,72 @@ // // TODO: Upstream this to `stun-codec`. +use bytecodec::fixnum::{U32beDecoder, U32beEncoder}; +use bytecodec::{ByteCount, Decode, Encode, Eos, ErrorKind, Result, SizedEncode, TryTaggedDecode}; +use std::fmt; use stun_codec::rfc5389::attributes::ErrorCode; +use stun_codec::{Attribute, AttributeType}; +use trackable::{track, track_panic}; + +macro_rules! impl_decode { + ($decoder:ty, $item:ident, $and_then:expr) => { + impl Decode for $decoder { + type Item = $item; + + fn decode(&mut self, buf: &[u8], eos: Eos) -> Result { + track!(self.0.decode(buf, eos)) + } + + fn finish_decoding(&mut self) -> Result { + track!(self.0.finish_decoding()).and_then($and_then) + } + + fn requiring_bytes(&self) -> ByteCount { + self.0.requiring_bytes() + } + + fn is_idle(&self) -> bool { + self.0.is_idle() + } + } + impl TryTaggedDecode for $decoder { + type Tag = AttributeType; + + fn try_start_decoding(&mut self, attr_type: Self::Tag) -> Result { + Ok(attr_type.as_u16() == $item::CODEPOINT) + } + } + }; +} + +macro_rules! impl_encode { + ($encoder:ty, $item:ty, $map_from:expr) => { + impl Encode for $encoder { + type Item = $item; + + fn encode(&mut self, buf: &mut [u8], eos: Eos) -> Result { + track!(self.0.encode(buf, eos)) + } + + fn start_encoding(&mut self, item: Self::Item) -> Result<()> { + track!(self.0.start_encoding($map_from(item).into())) + } + + fn requiring_bytes(&self) -> ByteCount { + self.0.requiring_bytes() + } + + fn is_idle(&self) -> bool { + self.0.is_idle() + } + } + impl SizedEncode for $encoder { + fn exact_requiring_bytes(&self) -> u64 { + self.0.exact_requiring_bytes() + } + } + }; +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct PeerAddressFamilyMismatch; @@ -19,3 +84,237 @@ impl From for ErrorCode { .expect("never fails") } } + +/// The family of an IP address, either IPv4 or IPv6. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum AddressFamily { + V4, + V6, +} + +impl fmt::Display for AddressFamily { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AddressFamily::V4 => write!(f, "IPv4"), + AddressFamily::V6 => write!(f, "IPv6"), + } + } +} + +const FAMILY_IPV4: u8 = 1; +const FAMILY_IPV6: u8 = 2; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RequestedAddressFamily(AddressFamily); +impl RequestedAddressFamily { + /// The codepoint of the type of the attribute. + pub const CODEPOINT: u16 = 0x0017; + + /// Makes a new `RequestedAddressFamily` instance. + pub fn new(fam: AddressFamily) -> Self { + RequestedAddressFamily(fam) + } + + /// Returns the requested address family. + pub fn address_family(&self) -> AddressFamily { + self.0 + } +} +impl Attribute for RequestedAddressFamily { + type Decoder = RequestedAddressFamilyDecoder; + type Encoder = RequestedAddressFamilyEncoder; + + fn get_type(&self) -> AttributeType { + AttributeType::new(Self::CODEPOINT) + } +} + +/// [`RequestedAddressFamily`] decoder. +#[derive(Debug, Default)] +pub struct RequestedAddressFamilyDecoder(AddressFamilyDecoder); +impl RequestedAddressFamilyDecoder { + /// Makes a new `RequestedAddressFamilyDecoder` instance. + pub fn new() -> Self { + Self::default() + } +} +impl_decode!( + RequestedAddressFamilyDecoder, + RequestedAddressFamily, + |item| Ok(RequestedAddressFamily(item)) +); + +/// [`RequestedAddressFamily`] encoder. +#[derive(Debug, Default)] +pub struct RequestedAddressFamilyEncoder(AddressFamilyEncoder); +impl RequestedAddressFamilyEncoder { + /// Makes a new `RequestedAddressFamilyEncoder` instance. + pub fn new() -> Self { + Self::default() + } +} +impl_encode!( + RequestedAddressFamilyEncoder, + RequestedAddressFamily, + |item: Self::Item| { item.0 } +); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AdditionalAddressFamily(AddressFamily); +impl AdditionalAddressFamily { + /// The codepoint of the type of the attribute. + pub const CODEPOINT: u16 = 0x8000; + + /// Makes a new `AdditionalAddressFamily` instance. + pub fn new(fam: AddressFamily) -> Self { + AdditionalAddressFamily(fam) + } + + /// Returns the requested address family. + pub fn address_family(&self) -> AddressFamily { + self.0 + } +} +impl Attribute for AdditionalAddressFamily { + type Decoder = AdditionalAddressFamilyDecoder; + type Encoder = AdditionalAddressFamilyEncoder; + + fn get_type(&self) -> AttributeType { + AttributeType::new(Self::CODEPOINT) + } +} + +/// [`AdditionalAddressFamily`] decoder. +#[derive(Debug, Default)] +pub struct AdditionalAddressFamilyDecoder(AddressFamilyDecoder); + +impl_decode!( + AdditionalAddressFamilyDecoder, + AdditionalAddressFamily, + |item| Ok(AdditionalAddressFamily(item)) +); + +/// [`AdditionalAddressFamily`] encoder. +#[derive(Debug, Default)] +pub struct AdditionalAddressFamilyEncoder(AddressFamilyEncoder); +impl_encode!( + AdditionalAddressFamilyEncoder, + AdditionalAddressFamily, + |item: Self::Item| { item.0 } +); + +/// [`RequestedAddressFamily`] decoder. +#[derive(Debug, Default)] +pub struct AddressFamilyDecoder { + family: U32beDecoder, +} + +impl Decode for AddressFamilyDecoder { + type Item = AddressFamily; + + fn decode(&mut self, buf: &[u8], eos: Eos) -> Result { + self.family.decode(buf, eos) + } + + fn finish_decoding(&mut self) -> Result { + let [fam, _, _, _] = self.family.finish_decoding()?.to_be_bytes(); + + match fam { + FAMILY_IPV4 => Ok(AddressFamily::V4), + FAMILY_IPV6 => Ok(AddressFamily::V6), + family => track_panic!( + ErrorKind::InvalidInput, + "Unknown address family: {}", + family + ), + } + } + + fn requiring_bytes(&self) -> ByteCount { + self.family.requiring_bytes() + } + + fn is_idle(&self) -> bool { + self.family.is_idle() + } +} + +/// [`RequestedAddressFamily`] decoder. +#[derive(Debug, Default)] +pub struct AddressFamilyEncoder { + family: U32beEncoder, +} + +impl Encode for AddressFamilyEncoder { + type Item = AddressFamily; + + fn encode(&mut self, buf: &mut [u8], eos: Eos) -> Result { + self.family.encode(buf, eos) + } + + fn start_encoding(&mut self, item: Self::Item) -> Result<()> { + let fam_byte = match item { + AddressFamily::V4 => FAMILY_IPV4, + AddressFamily::V6 => FAMILY_IPV6, + }; + + let bytes = [fam_byte, 0, 0, 0]; + + self.family.start_encoding(u32::from_be_bytes(bytes)) + } + + fn requiring_bytes(&self) -> ByteCount { + ByteCount::Finite(self.exact_requiring_bytes()) + } +} + +impl SizedEncode for AddressFamilyEncoder { + fn exact_requiring_bytes(&self) -> u64 { + self.family.exact_requiring_bytes() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct AddressFamilyNotSupported; + +impl AddressFamilyNotSupported { + /// The codepoint of the error. + pub const CODEPOINT: u16 = 440; +} +impl From for ErrorCode { + fn from(_: AddressFamilyNotSupported) -> Self { + ErrorCode::new( + AddressFamilyNotSupported::CODEPOINT, + "Address Family not Supported".to_string(), + ) + .expect("never fails") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytecodec::{DecodeExt, EncodeExt}; + + #[test] + fn address_family_encoder_works() { + let mut encoder = AddressFamilyEncoder::default(); + + let bytes = encoder.encode_into_bytes(AddressFamily::V4).unwrap(); + assert_eq!(bytes, [1, 0, 0, 0]); + + let bytes = encoder.encode_into_bytes(AddressFamily::V6).unwrap(); + assert_eq!(bytes, [2, 0, 0, 0]); + } + + #[test] + fn address_family_decoder_works() { + let mut decoder = AddressFamilyDecoder::default(); + + let fam = decoder.decode_from_bytes(&[1, 0, 0, 0]).unwrap(); + assert_eq!(fam, AddressFamily::V4); + + let fam = decoder.decode_from_bytes(&[2, 0, 0, 0]).unwrap(); + assert_eq!(fam, AddressFamily::V6); + } +} diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index 05c3dd4e8..43f8ec0f3 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -7,9 +7,13 @@ pub use crate::server::client_message::{ }; use crate::auth::{MessageIntegrityExt, Nonces, FIREZONE}; -use crate::rfc8656::PeerAddressFamilyMismatch; +use crate::net_ext::IpAddrExt; +use crate::rfc8656::{ + AdditionalAddressFamily, AddressFamily, AddressFamilyNotSupported, PeerAddressFamilyMismatch, + RequestedAddressFamily, +}; use crate::stun_codec_ext::{MessageClassExt, MethodExt}; -use crate::TimeEvents; +use crate::{IpStack, TimeEvents}; use anyhow::Result; use bytecodec::EncodeExt; use core::fmt; @@ -20,7 +24,7 @@ use prometheus_client::registry::Registry; use rand::Rng; use std::collections::{HashMap, VecDeque}; use std::hash::Hash; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::net::{IpAddr, SocketAddr}; use std::time::{Duration, SystemTime}; use stun_codec::rfc5389::attributes::{ ErrorCode, MessageIntegrity, Nonce, Realm, Username, XorMappedAddress, @@ -47,7 +51,7 @@ pub struct Server { decoder: client_message::Decoder, encoder: MessageEncoder, - public_ip4_address: Ipv4Addr, + public_address: IpStack, /// All client allocations, indexed by client's socket address. allocations: HashMap, @@ -82,12 +86,21 @@ pub enum Command { payload: Vec, recipient: SocketAddr, }, - /// Listen for traffic on the provided IP addresses. + /// Listen for traffic on the provided port [AddressFamily]. /// /// Any incoming data should be handed to the [`Server`] via [`Server::handle_relay_input`]. - AllocateAddresses { id: AllocationId, port: u16 }, - /// Free the addresses associated with the given [`AllocationId`]. - FreeAddresses { id: AllocationId }, + /// A single allocation can reference one of either [AddressFamily]s or both. + /// Only the combination of [AllocationId] and [AddressFamily] is unique. + CreateAllocation { + id: AllocationId, + family: AddressFamily, + port: u16, + }, + /// Free the allocation associated with the given [`AllocationId`] and [AddressFamily] + FreeAllocation { + id: AllocationId, + family: AddressFamily, + }, ForwardData { id: AllocationId, @@ -135,7 +148,7 @@ impl Server where R: Rng, { - pub fn new(public_ip4_address: Ipv4Addr, mut rng: R, registry: &mut Registry) -> Self { + pub fn new(public_address: impl Into, mut rng: R, registry: &mut Registry) -> Self { // TODO: Validate that local IP isn't multicast / loopback etc. let allocations_gauge = Gauge::default(); @@ -162,7 +175,7 @@ where Self { decoder: Default::default(), encoder: Default::default(), - public_ip4_address, + public_address: public_address.into(), allocations: Default::default(), clients_by_allocation: Default::default(), allocations_by_port: Default::default(), @@ -422,10 +435,22 @@ where return Err(error_response(BadRequest, &request)); } + let (first_relay_address, maybe_second_relay_addr) = derive_relay_addresses( + self.public_address, + request.requested_address_family(), + request.additional_address_family(), + ) + .map_err(|e| error_response(e, &request))?; + // TODO: Do we need to handle DONT-FRAGMENT? // TODO: Do we need to handle EVEN/ODD-PORT? - let allocation = self.create_new_allocation(now, &effective_lifetime); + let allocation = self.create_new_allocation( + now, + &effective_lifetime, + first_relay_address, + maybe_second_relay_addr, + ); let mut message = Message::new( MessageClass::SuccessResponse, @@ -433,9 +458,16 @@ where request.transaction_id(), ); - let ip4_relay_address = self.public_relay_address_for_port(allocation.port); + let port = allocation.port; + + message + .add_attribute(XorRelayAddress::new(SocketAddr::new(first_relay_address, port)).into()); + if let Some(second_relay_address) = maybe_second_relay_addr { + message.add_attribute( + XorRelayAddress::new(SocketAddr::new(second_relay_address, port)).into(), + ); + } - message.add_attribute(XorRelayAddress::new(ip4_relay_address.into()).into()); message.add_attribute(XorMappedAddress::new(sender).into()); message.add_attribute(effective_lifetime.clone().into()); @@ -446,17 +478,34 @@ where self.pending_commands.push_back(Command::Wake { deadline: wake_deadline, }); - self.pending_commands.push_back(Command::AllocateAddresses { + self.pending_commands.push_back(Command::CreateAllocation { id: allocation.id, - port: allocation.port, + family: first_relay_address.family(), + port, }); + if let Some(second_relay_addr) = maybe_second_relay_addr { + self.pending_commands.push_back(Command::CreateAllocation { + id: allocation.id, + family: second_relay_addr.family(), + port, + }); + } self.send_message(message, sender); - tracing::info!( - target: "relay", - ip4_relay_address = field::display(ip4_relay_address), - "Created new allocation", - ); + if let Some(second_relay_addr) = maybe_second_relay_addr { + tracing::info!( + target: "relay", + first_relay_address = field::display(first_relay_address), + second_relay_address = field::display(second_relay_addr), + "Created new allocation", + ) + } else { + tracing::info!( + target: "relay", + first_relay_address = field::display(first_relay_address), + "Created new allocation", + ) + } self.clients_by_allocation.insert(allocation.id, sender); self.allocations.insert(sender, allocation); @@ -693,7 +742,13 @@ where Ok(()) } - fn create_new_allocation(&mut self, now: SystemTime, lifetime: &Lifetime) -> Allocation { + fn create_new_allocation( + &mut self, + now: SystemTime, + lifetime: &Lifetime, + first_relay_addr: IpAddr, + second_relay_addr: Option, + ) -> Allocation { // First, find an unused port. assert!( @@ -718,6 +773,8 @@ where id, port, expires_at: now + lifetime.lifetime(), + first_relay_addr, + second_relay_addr, } } @@ -783,10 +840,6 @@ where .inc(); } - fn public_relay_address_for_port(&self, port: u16) -> SocketAddrV4 { - SocketAddrV4::new(self.public_ip4_address, port) - } - fn get_allocation(&self, id: &AllocationId) -> Option<&Allocation> { self.clients_by_allocation .get(id) @@ -805,9 +858,18 @@ where let port = allocation.port; self.allocations_by_port.remove(&port); + self.allocations_gauge.dec(); - self.pending_commands - .push_back(Command::FreeAddresses { id }); + self.pending_commands.push_back(Command::FreeAllocation { + id, + family: allocation.first_relay_addr.family(), + }); + if let Some(second_relay_addr) = allocation.second_relay_addr { + self.pending_commands.push_back(Command::FreeAllocation { + id, + family: second_relay_addr.family(), + }) + } tracing::info!(target: "relay", %port, "Deleted allocation"); } @@ -851,6 +913,9 @@ struct Allocation { /// Data arriving on this port will be forwarded to the client iff there is an active data channel. port: u16, expires_at: SystemTime, + + first_relay_addr: IpAddr, + second_relay_addr: Option, } struct Channel { @@ -919,6 +984,43 @@ fn error_response( message } +/// Derive the relay address for the client based on the request and the supported IP stack of the relay server. +/// +/// By default, a client gets an IPv4 address. +/// They can request an _additional_ IPv6 address or only an IPv6 address. +/// This is handled with two different STUN attributes: [AdditionalAddressFamily] and [RequestedAddressFamily]. +/// +/// The specification mandates certain checks for how these attributes can be used. +/// In a nutshell, the requirements constrain the use such that there is only one way of doing things. +/// For example, it is disallowed to use [RequestedAddressFamily] for IPv6 and requested and an IPv4 address via [AdditionalAddressFamily]. +/// If this is desired, clients should simply use [AdditionalAddressFamily] for IPv6. +fn derive_relay_addresses( + public_address: IpStack, + requested_addr_family: Option<&RequestedAddressFamily>, + additional_addr_family: Option<&AdditionalAddressFamily>, +) -> Result<(IpAddr, Option), ErrorCode> { + match ( + public_address, + requested_addr_family.map(|r| r.address_family()), + additional_addr_family.map(|a| a.address_family()), + ) { + ( + IpStack::Ip4(addr) | IpStack::Dual { ip4: addr, .. }, + None | Some(AddressFamily::V4), + None, + ) => Ok((addr.into(), None)), + (IpStack::Ip6(addr) | IpStack::Dual { ip6: addr, .. }, Some(AddressFamily::V6), None) => { + Ok((addr.into(), None)) + } + (IpStack::Dual { ip4, ip6 }, None, Some(AddressFamily::V6)) => { + Ok((ip4.into(), Some(ip6.into()))) + } + (_, Some(_), Some(_)) => Err(BadRequest.into()), + (_, _, Some(AddressFamily::V4)) => Err(BadRequest.into()), + _ => Err(AddressFamilyNotSupported.into()), + } +} + /// Private helper trait to make [`error_response`] more ergonomic to use. trait StunRequest { fn transaction_id(&self) -> TransactionId; @@ -990,7 +1092,9 @@ stun_codec::define_attribute_enums!( XorPeerAddress, Nonce, Realm, - Username + Username, + RequestedAddressFamily, + AdditionalAddressFamily ] ); @@ -1014,3 +1118,65 @@ enum MessageType { CreatePermission, Refresh, } + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; + + // Tests for requirements listed in https://www.rfc-editor.org/rfc/rfc8656#name-receiving-an-allocate-reque. + + // 6. The server checks if the request contains both REQUESTED-ADDRESS-FAMILY and ADDITIONAL-ADDRESS-FAMILY attributes. If yes, then the server rejects the request with a 400 (Bad Request) error. + #[test] + fn requested_and_additional_is_bad_request() { + let error_code = derive_relay_addresses( + IpStack::Ip4(Ipv4Addr::LOCALHOST), + Some(&RequestedAddressFamily::new(AddressFamily::V4)), + Some(&AdditionalAddressFamily::new(AddressFamily::V6)), + ) + .unwrap_err(); + + assert_eq!(error_code.code(), BadRequest::CODEPOINT) + } + + // 7. If the server does not support the address family requested by the client in REQUESTED-ADDRESS-FAMILY, or if the allocation of the requested address family is disabled by local policy, it MUST generate an Allocate error response, and it MUST include an ERROR-CODE attribute with the 440 (Address Family not Supported) response code. + // If the REQUESTED-ADDRESS-FAMILY attribute is absent and the server does not support the IPv4 address family, the server MUST include an ERROR-CODE attribute with the 440 (Address Family not Supported) response code. + #[test] + fn requested_address_family_not_available_is_not_supported() { + let error_code = derive_relay_addresses( + IpStack::Ip4(Ipv4Addr::LOCALHOST), + Some(&RequestedAddressFamily::new(AddressFamily::V6)), + None, + ) + .unwrap_err(); + + assert_eq!(error_code.code(), AddressFamilyNotSupported::CODEPOINT); + + let error_code = derive_relay_addresses( + IpStack::Ip6(Ipv6Addr::LOCALHOST), + Some(&RequestedAddressFamily::new(AddressFamily::V4)), + None, + ) + .unwrap_err(); + + assert_eq!(error_code.code(), AddressFamilyNotSupported::CODEPOINT); + + let error_code = + derive_relay_addresses(IpStack::Ip6(Ipv6Addr::LOCALHOST), None, None).unwrap_err(); + + assert_eq!(error_code.code(), AddressFamilyNotSupported::CODEPOINT) + } + + //9. The server checks if the request contains an ADDITIONAL-ADDRESS-FAMILY attribute. If yes, and the attribute value is 0x01 (IPv4 address family), then the server rejects the request with a 400 (Bad Request) error. + #[test] + fn additional_address_family_ip4_is_bad_request() { + let error_code = derive_relay_addresses( + IpStack::Ip4(Ipv4Addr::LOCALHOST), + None, + Some(&AdditionalAddressFamily::new(AddressFamily::V4)), + ) + .unwrap_err(); + + assert_eq!(error_code.code(), BadRequest::CODEPOINT) + } +} diff --git a/rust/relay/src/server/client_message.rs b/rust/relay/src/server/client_message.rs index b7c1e309c..acb04ebf6 100644 --- a/rust/relay/src/server/client_message.rs +++ b/rust/relay/src/server/client_message.rs @@ -1,4 +1,5 @@ use crate::auth::{generate_password, split_username, systemtime_from_unix, FIREZONE}; +use crate::rfc8656::{AdditionalAddressFamily, AddressFamily, RequestedAddressFamily}; use crate::server::channel_data::ChannelData; use crate::server::UDP_TRANSPORT; use crate::Attribute; @@ -121,37 +122,26 @@ pub struct Allocate { lifetime: Option, username: Option, nonce: Option, + requested_address_family: Option, + additional_address_family: Option, } impl Allocate { - pub fn new_authenticated_udp( + pub fn new_authenticated_udp_implicit_ip4( transaction_id: TransactionId, lifetime: Option, username: Username, relay_secret: &str, nonce: Uuid, ) -> Self { - let requested_transport = RequestedTransport::new(UDP_TRANSPORT); - let nonce = Nonce::new(nonce.as_hyphenated().to_string()).expect("len(uuid) < 128"); - - let mut message = - Message::::new(MessageClass::Request, ALLOCATE, transaction_id); - message.add_attribute(requested_transport.clone().into()); - message.add_attribute(username.clone().into()); - message.add_attribute(nonce.clone().into()); - - if let Some(lifetime) = &lifetime { - message.add_attribute(lifetime.clone().into()); - } - - let (expiry, salt) = split_username(username.name()).expect("a valid username"); - let expiry_systemtime = systemtime_from_unix(expiry); - - let password = generate_password(relay_secret, expiry_systemtime, salt); - - let message_integrity = - MessageIntegrity::new_long_term_credential(&message, &username, &FIREZONE, &password) - .unwrap(); + let (requested_transport, nonce, message_integrity) = Self::make_attributes( + transaction_id, + &lifetime, + &username, + relay_secret, + nonce, + None, + ); Self { transaction_id, @@ -160,6 +150,38 @@ impl Allocate { lifetime, username: Some(username), nonce: Some(nonce), + requested_address_family: None, // IPv4 is the default. + additional_address_family: None, + } + } + + pub fn new_authenticated_udp_ip6( + transaction_id: TransactionId, + lifetime: Option, + username: Username, + relay_secret: &str, + nonce: Uuid, + ) -> Self { + let requested_address_family = RequestedAddressFamily::new(AddressFamily::V6); + + let (requested_transport, nonce, message_integrity) = Self::make_attributes( + transaction_id, + &lifetime, + &username, + relay_secret, + nonce, + Some(requested_address_family.clone()), + ); + + Self { + transaction_id, + message_integrity: Some(message_integrity), + requested_transport, + lifetime, + username: Some(username), + nonce: Some(nonce), + requested_address_family: Some(requested_address_family), + additional_address_family: None, } } @@ -184,9 +206,47 @@ impl Allocate { lifetime, username: None, nonce: None, + requested_address_family: None, + additional_address_family: None, } } + fn make_attributes( + transaction_id: TransactionId, + lifetime: &Option, + username: &Username, + relay_secret: &str, + nonce: Uuid, + requested_address_family: Option, + ) -> (RequestedTransport, Nonce, MessageIntegrity) { + let requested_transport = RequestedTransport::new(UDP_TRANSPORT); + let nonce = Nonce::new(nonce.as_hyphenated().to_string()).expect("len(uuid) < 128"); + + let mut message = + Message::::new(MessageClass::Request, ALLOCATE, transaction_id); + message.add_attribute(requested_transport.clone().into()); + message.add_attribute(username.clone().into()); + message.add_attribute(nonce.clone().into()); + + if let Some(requested_address_family) = requested_address_family { + message.add_attribute(requested_address_family.into()); + } + + if let Some(lifetime) = &lifetime { + message.add_attribute(lifetime.clone().into()); + } + + let (expiry, salt) = split_username(username.name()).expect("a valid username"); + let expiry_systemtime = systemtime_from_unix(expiry); + + let password = generate_password(relay_secret, expiry_systemtime, salt); + + let message_integrity = + MessageIntegrity::new_long_term_credential(&message, username, &FIREZONE, &password) + .unwrap(); + (requested_transport, nonce, message_integrity) + } + pub fn parse(message: &Message) -> Result> { let transaction_id = message.transaction_id(); let message_integrity = message.get_attribute::().cloned(); @@ -197,6 +257,8 @@ impl Allocate { .clone(); let lifetime = message.get_attribute::().cloned(); let username = message.get_attribute::().cloned(); + let requested_address_family = message.get_attribute::().cloned(); + let additional_address_family = message.get_attribute::().cloned(); Ok(Allocate { transaction_id, @@ -205,6 +267,8 @@ impl Allocate { lifetime, username, nonce, + requested_address_family, + additional_address_family, }) } @@ -231,6 +295,14 @@ impl Allocate { pub fn nonce(&self) -> Option<&Nonce> { self.nonce.as_ref() } + + pub fn requested_address_family(&self) -> Option<&RequestedAddressFamily> { + self.requested_address_family.as_ref() + } + + pub fn additional_address_family(&self) -> Option<&AdditionalAddressFamily> { + self.additional_address_family.as_ref() + } } pub struct Refresh { diff --git a/rust/relay/src/udp_socket.rs b/rust/relay/src/udp_socket.rs index 8359d781f..92cd74070 100644 --- a/rust/relay/src/udp_socket.rs +++ b/rust/relay/src/udp_socket.rs @@ -1,4 +1,5 @@ -use anyhow::Result; +use crate::{AddressFamily, SocketAddrExt}; +use anyhow::{Context as _, Result}; use std::net::SocketAddr; use std::task::{ready, Context, Poll}; use tokio::io::ReadBuf; @@ -12,9 +13,13 @@ pub struct UdpSocket { } impl UdpSocket { - pub async fn bind(addr: impl Into) -> Result { + pub fn bind(addr: impl Into) -> Result { + let addr = addr.into(); + let std_socket = make_std_socket(addr) + .with_context(|| format!("Failed to bind UDP socket to {addr}"))?; + Ok(Self { - inner: tokio::net::UdpSocket::bind(addr.into()).await?, + inner: tokio::net::UdpSocket::from_std(std_socket)?, recv_buf: [0u8; MAX_UDP_SIZE], }) } @@ -52,3 +57,25 @@ impl UdpSocket { Poll::Ready(Ok(())) } } + +/// Creates an [std::net::UdpSocket] via the [socket2] library that is configured for our needs. +/// +/// Most importantly, this sets the `IPV6_V6ONLY` flag to ensure we disallow IP4-mapped IPv6 addresses and can bind to IP4 and IP6 addresses on the same port. +fn make_std_socket(socket_addr: SocketAddr) -> Result { + use socket2::*; + + let domain = match socket_addr.family() { + AddressFamily::V4 => Domain::IPV4, + AddressFamily::V6 => Domain::IPV6, + }; + let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + if socket_addr.is_ipv6() { + socket.set_only_v6(true)?; + } + + socket.set_nonblocking(true)?; + socket.bind(&socket_addr.into())?; + + Ok(socket.into()) +} diff --git a/rust/relay/tests/regression.rs b/rust/relay/tests/regression.rs index af14298c4..6436a7fcd 100644 --- a/rust/relay/tests/regression.rs +++ b/rust/relay/tests/regression.rs @@ -2,12 +2,12 @@ use bytecodec::{DecodeExt, EncodeExt}; use prometheus_client::registry::Registry; use rand::rngs::mock::StepRng; use relay::{ - Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command, - Refresh, Server, + AddressFamily, Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, + ClientMessage, Command, IpStack, Refresh, Server, }; use std::collections::HashMap; use std::iter; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; use std::time::{Duration, SystemTime}; use stun_codec::rfc5389::attributes::{ErrorCode, Nonce, Realm, Username, XorMappedAddress}; use stun_codec::rfc5389::errors::Unauthorized; @@ -55,7 +55,7 @@ fn deallocate_once_time_expired( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( transaction_id, Some(lifetime.clone()), valid_username(now, &username_salt), @@ -66,7 +66,7 @@ fn deallocate_once_time_expired( ), [ Wake(now + lifetime.lifetime()), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response(transaction_id, public_relay_addr, 49152, source, &lifetime), @@ -76,7 +76,7 @@ fn deallocate_once_time_expired( server.assert_commands( forward_time_to(now + lifetime.lifetime() + Duration::from_secs(1)), - [FreeAllocation(49152)], + [FreeAllocation(49152, AddressFamily::V4)], ); } @@ -110,7 +110,7 @@ fn unauthenticated_allocate_triggers_authentication( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( transaction_id, Some(lifetime.clone()), valid_username(now, &username_salt), @@ -121,7 +121,7 @@ fn unauthenticated_allocate_triggers_authentication( ), [ Wake(now + lifetime.lifetime()), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response(transaction_id, public_relay_addr, 49152, source, &lifetime), @@ -149,7 +149,7 @@ fn when_refreshed_in_time_allocation_does_not_expire( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(allocate_lifetime.clone()), valid_username(now, &username_salt), @@ -160,7 +160,7 @@ fn when_refreshed_in_time_allocation_does_not_expire( ), [ Wake(first_wake), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response( @@ -225,7 +225,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(allocate_lifetime.clone()), valid_username(now, &username_salt), @@ -236,7 +236,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( ), [ Wake(first_wake), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response( @@ -266,7 +266,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( now, ), [ - FreeAllocation(49152), + FreeAllocation(49152, AddressFamily::V4), send_message( source, refresh_response( @@ -309,7 +309,7 @@ fn ping_pong_relay( server.assert_commands( from_client( source, - Allocate::new_authenticated_udp( + Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(lifetime.clone()), valid_username(now, &username_salt), @@ -320,7 +320,7 @@ fn ping_pong_relay( ), [ Wake(now + lifetime.lifetime()), - CreateAllocation(49152), + CreateAllocation(49152, AddressFamily::V4), send_message( source, allocate_response( @@ -375,13 +375,57 @@ fn ping_pong_relay( ); } +#[proptest] +fn can_make_ipv6_allocation( + #[strategy(relay::proptest::transaction_id())] transaction_id: TransactionId, + #[strategy(relay::proptest::allocation_lifetime())] lifetime: Lifetime, + #[strategy(relay::proptest::username_salt())] username_salt: String, + source: SocketAddrV4, + public_relay_ip4_addr: Ipv4Addr, + public_relay_ip6_addr: Ipv6Addr, + #[strategy(relay::proptest::now())] now: SystemTime, + #[strategy(relay::proptest::nonce())] nonce: Uuid, +) { + let mut server = + TestServer::new((public_relay_ip4_addr, public_relay_ip6_addr)).with_nonce(nonce); + let secret = server.auth_secret(); + + server.assert_commands( + from_client( + source, + Allocate::new_authenticated_udp_ip6( + transaction_id, + Some(lifetime.clone()), + valid_username(now, &username_salt), + secret, + nonce, + ), + now, + ), + [ + Wake(now + lifetime.lifetime()), + CreateAllocation(49152, AddressFamily::V6), + send_message( + source, + allocate_response( + transaction_id, + public_relay_ip6_addr, + 49152, + source, + &lifetime, + ), + ), + ], + ); +} + struct TestServer { server: Server, id_to_port: HashMap, } impl TestServer { - fn new(relay_public_addr: Ipv4Addr) -> Self { + fn new(relay_public_addr: impl Into) -> Self { Self { server: Server::new( relay_public_addr, @@ -421,8 +465,8 @@ impl TestServer { let msg = match expected_output { Output::SendMessage((recipient, msg)) => format!("to send message {:?} to {recipient}", msg), Wake(time) => format!("to be woken at {time:?}"), - CreateAllocation(port) => format!("to create allocation on port {port}"), - FreeAllocation(port) => format!("to free allocation on port {port}"), + CreateAllocation(port, family) => format!("to create allocation on port {port} for address family {family}"), + FreeAllocation(port, family) => format!("to free allocation on port {port} for address family {family}"), Output::SendChannelData((peer, _)) => format!("to send channel data from {peer} to client"), Output::Forward((peer, _, _)) => format!("to forward data to peer {peer}") }; @@ -452,18 +496,27 @@ impl TestServer { assert_eq!(when, deadline); } ( - CreateAllocation(expected_port), - Command::AllocateAddresses { + CreateAllocation(expected_port, expected_family), + Command::CreateAllocation { id, + family: actual_family, port: actual_port, }, ) => { self.id_to_port.insert(actual_port, id); assert_eq!(expected_port, actual_port); + assert_eq!(expected_family, actual_family); } - (FreeAllocation(port), Command::FreeAddresses { id }) => { + ( + FreeAllocation(port, family), + Command::FreeAllocation { + id, + family: actual_family, + }, + ) => { let actual_id = self.id_to_port.remove(&port).expect("to have port in map"); assert_eq!(id, actual_id); + assert_eq!(family, actual_family); } (Wake(when), Command::SendMessage { payload, .. }) => { panic!( @@ -530,7 +583,7 @@ fn binding_response( fn allocate_response( transaction_id: TransactionId, - public_relay_addr: Ipv4Addr, + public_relay_addr: impl Into, port: u16, source: SocketAddrV4, lifetime: &Lifetime, @@ -538,7 +591,7 @@ fn allocate_response( let mut message = Message::::new(MessageClass::SuccessResponse, ALLOCATE, transaction_id); message.add_attribute( - XorRelayAddress::new(SocketAddrV4::new(public_relay_addr, port).into()).into(), + XorRelayAddress::new(SocketAddr::new(public_relay_addr.into(), port)).into(), ); message.add_attribute(XorMappedAddress::new(source.into()).into()); message.add_attribute(lifetime.clone().into()); @@ -610,8 +663,8 @@ enum Output<'a> { SendChannelData((SocketAddr, ChannelData<'a>)), Forward((SocketAddr, Vec, u16)), Wake(SystemTime), - CreateAllocation(u16), - FreeAllocation(u16), + CreateAllocation(u16, AddressFamily), + FreeAllocation(u16, AddressFamily), } fn send_message<'a>(source: impl Into, message: Message) -> Output<'a> {