diff --git a/rust/relay/src/allocation.rs b/rust/relay/src/allocation.rs index cb6cf7a18..9132aef93 100644 --- a/rust/relay/src/allocation.rs +++ b/rust/relay/src/allocation.rs @@ -1,11 +1,10 @@ use crate::server::AllocationId; use crate::udp_socket::UdpSocket; -use crate::AddressFamily; +use crate::{AddressFamily, PeerSocket}; use anyhow::{bail, Result}; use futures::channel::mpsc; use futures::{SinkExt, StreamExt}; use std::convert::Infallible; -use std::net::SocketAddr; use tokio::task; /// The maximum amount of items that can be buffered in the channel to the allocation task. @@ -18,12 +17,12 @@ pub struct Allocation { /// /// Stored here to make resource-cleanup easy. handle: task::JoinHandle<()>, - sender: mpsc::Sender<(Vec, SocketAddr)>, + sender: mpsc::Sender<(Vec, PeerSocket)>, } impl Allocation { pub fn new( - relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, + relay_data_sender: mpsc::Sender<(Vec, PeerSocket, AllocationId)>, id: AllocationId, family: AddressFamily, port: u16, @@ -62,7 +61,7 @@ impl Allocation { /// /// All our data is relayed over UDP which by design is an unreliable protocol. /// Thus, any application running on top of this relay must already account for potential packet loss. - pub fn send(&mut self, data: Vec, recipient: SocketAddr) -> Result<()> { + pub fn send(&mut self, data: Vec, recipient: PeerSocket) -> Result<()> { match self.sender.try_send((data, recipient)) { Ok(()) => Ok(()), Err(e) if e.is_disconnected() => { @@ -89,8 +88,8 @@ impl Drop for Allocation { } async fn forward_incoming_relay_data( - mut relayed_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, - mut client_to_peer_receiver: mpsc::Receiver<(Vec, SocketAddr)>, + mut relayed_data_sender: mpsc::Sender<(Vec, PeerSocket, AllocationId)>, + mut client_to_peer_receiver: mpsc::Receiver<(Vec, PeerSocket)>, id: AllocationId, family: AddressFamily, port: u16, @@ -101,11 +100,11 @@ async fn forward_incoming_relay_data( tokio::select! { result = socket.recv() => { let (data, sender) = result?; - relayed_data_sender.send((data.to_vec(), sender, id)).await?; + relayed_data_sender.send((data.to_vec(), PeerSocket::new(sender), id)).await?; } Some((data, recipient)) = client_to_peer_receiver.next() => { - socket.send_to(&data, recipient).await?; + socket.send_to(&data, recipient.into_socket()).await?; } } } diff --git a/rust/relay/src/lib.rs b/rust/relay/src/lib.rs index ba096dd54..2c283e28d 100644 --- a/rust/relay/src/lib.rs +++ b/rust/relay/src/lib.rs @@ -11,7 +11,7 @@ pub mod health_check; pub mod proptest; pub use allocation::Allocation; -pub use net_ext::{IpAddrExt, SocketAddrExt}; +pub use net_ext::IpAddrExt; pub use server::{ Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command, CreatePermission, Refresh, Server, @@ -22,7 +22,10 @@ pub use udp_socket::UdpSocket; pub(crate) use time_events::TimeEvents; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::{ + fmt, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, +}; /// Describes the IP stack of a relay server. #[derive(Debug, Copy, Clone)] @@ -67,3 +70,65 @@ impl From<(Ipv4Addr, Ipv6Addr)> for IpStack { IpStack::Dual { ip4, ip6 } } } + +/// New-type for a client's socket. +/// +/// From the [spec](https://www.rfc-editor.org/rfc/rfc8656#section-2-4.4): +/// +/// > A STUN client that implements this specification. +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub struct ClientSocket(SocketAddr); + +impl ClientSocket { + pub fn new(addr: SocketAddr) -> Self { + Self(addr) + } + + pub fn into_socket(self) -> SocketAddr { + self.0 + } + + pub fn family(&self) -> AddressFamily { + match self.0 { + SocketAddr::V4(_) => AddressFamily::V4, + SocketAddr::V6(_) => AddressFamily::V6, + } + } +} + +impl fmt::Display for ClientSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +/// New-type for a peer's socket. +/// +/// From the [spec](https://www.rfc-editor.org/rfc/rfc8656#section-2-4.8): +/// +/// > A host with which the TURN client wishes to communicate. The TURN server relays traffic between the TURN client and its peer(s). The peer does not interact with the TURN server using the protocol defined in this document; rather, the peer receives data sent by the TURN server, and the peer sends data towards the TURN server. +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub struct PeerSocket(SocketAddr); + +impl PeerSocket { + pub fn new(addr: SocketAddr) -> Self { + Self(addr) + } + + pub fn family(&self) -> AddressFamily { + match self.0 { + SocketAddr::V4(_) => AddressFamily::V4, + SocketAddr::V6(_) => AddressFamily::V6, + } + } + + pub fn into_socket(self) -> SocketAddr { + self.0 + } +} + +impl fmt::Display for PeerSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 6592e544e..b092c2854 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -2,8 +2,8 @@ use anyhow::{anyhow, bail, Context, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; use firezone_relay::{ - AddressFamily, Allocation, AllocationId, Command, IpStack, Server, Sleep, SocketAddrExt, - UdpSocket, + AddressFamily, Allocation, AllocationId, ClientSocket, Command, IpStack, PeerSocket, Server, + Sleep, UdpSocket, }; use futures::channel::mpsc; use futures::{future, FutureExt, SinkExt, StreamExt}; @@ -313,14 +313,14 @@ fn make_rng(seed: Option) -> StdRng { } struct Eventloop { - inbound_data_receiver: mpsc::Receiver<(Vec, SocketAddr)>, - outbound_ip4_data_sender: mpsc::Sender<(Vec, SocketAddr)>, - outbound_ip6_data_sender: mpsc::Sender<(Vec, SocketAddr)>, + inbound_data_receiver: mpsc::Receiver<(Vec, ClientSocket)>, + outbound_ip4_data_sender: mpsc::Sender<(Vec, ClientSocket)>, + outbound_ip6_data_sender: mpsc::Sender<(Vec, ClientSocket)>, server: Server, channel: Option>, allocations: HashMap<(AllocationId, AddressFamily), Allocation>, - relay_data_sender: mpsc::Sender<(Vec, SocketAddr, AllocationId)>, - relay_data_receiver: mpsc::Receiver<(Vec, SocketAddr, AllocationId)>, + relay_data_sender: mpsc::Sender<(Vec, PeerSocket, AllocationId)>, + relay_data_receiver: mpsc::Receiver<(Vec, PeerSocket, AllocationId)>, sleep: Sleep, stats_log_interval: tokio::time::Interval, @@ -338,10 +338,8 @@ where ) -> Result { let (relay_data_sender, relay_data_receiver) = mpsc::channel(1); let (inbound_data_sender, inbound_data_receiver) = mpsc::channel(1000); - let (outbound_ip4_data_sender, outbound_ip4_data_receiver) = - mpsc::channel::<(Vec, SocketAddr)>(1000); - let (outbound_ip6_data_sender, outbound_ip6_data_receiver) = - mpsc::channel::<(Vec, SocketAddr)>(1000); + let (outbound_ip4_data_sender, outbound_ip4_data_receiver) = mpsc::channel(1000); + let (outbound_ip6_data_sender, outbound_ip6_data_receiver) = mpsc::channel(1000); if public_address.as_v4().is_some() { tokio::spawn(main_udp_socket_task( @@ -563,8 +561,8 @@ fn fmt_human_throughput(mut throughput: f64) -> String { async fn main_udp_socket_task( family: AddressFamily, - mut inbound_data_sender: mpsc::Sender<(Vec, SocketAddr)>, - mut outbound_data_receiver: mpsc::Receiver<(Vec, SocketAddr)>, + mut inbound_data_sender: mpsc::Sender<(Vec, ClientSocket)>, + mut outbound_data_receiver: mpsc::Receiver<(Vec, ClientSocket)>, ) -> Result { let mut socket = UdpSocket::bind(family, 3478)?; @@ -572,11 +570,11 @@ async fn main_udp_socket_task( tokio::select! { result = socket.recv() => { let (data, sender) = result?; - inbound_data_sender.send((data.to_vec(), sender)).await?; + inbound_data_sender.send((data.to_vec(), ClientSocket::new(sender))).await?; } maybe_item = outbound_data_receiver.next() => { let (data, recipient) = maybe_item.context("Outbound data channel closed")?; - socket.send_to(data.as_ref(), recipient).await?; + socket.send_to(data.as_ref(), recipient.into_socket()).await?; } } } diff --git a/rust/relay/src/net_ext.rs b/rust/relay/src/net_ext.rs index 09df59293..139bd208d 100644 --- a/rust/relay/src/net_ext.rs +++ b/rust/relay/src/net_ext.rs @@ -1,4 +1,4 @@ -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use stun_codec::rfc8656::attributes::AddressFamily; pub trait IpAddrExt { @@ -13,16 +13,3 @@ impl IpAddrExt for IpAddr { } } } - -pub trait SocketAddrExt { - fn family(&self) -> AddressFamily; -} - -impl SocketAddrExt for SocketAddr { - fn family(&self) -> AddressFamily { - match self { - SocketAddr::V4(_) => AddressFamily::V4, - SocketAddr::V6(_) => AddressFamily::V6, - } - } -} diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index 6e813a478..af3e89b15 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -8,7 +8,7 @@ pub use crate::server::client_message::{ use crate::auth::{MessageIntegrityExt, Nonces, FIREZONE}; use crate::net_ext::IpAddrExt; -use crate::{IpStack, TimeEvents}; +use crate::{ClientSocket, IpStack, PeerSocket, TimeEvents}; use anyhow::Result; use bytecodec::EncodeExt; use core::fmt; @@ -53,15 +53,17 @@ pub struct Server { public_address: IpStack, /// All client allocations, indexed by client's socket address. - allocations: HashMap, - clients_by_allocation: HashMap, + allocations: HashMap, + clients_by_allocation: HashMap, allocations_by_port: HashMap, lowest_port: u16, highest_port: u16, - channels_by_client_and_number: HashMap<(SocketAddr, u16), Channel>, - channel_numbers_by_peer: HashMap, + /// Channel numbers are unique by client, thus indexed by both. + channels_by_client_and_number: HashMap<(ClientSocket, u16), Channel>, + /// Channel numbers are unique between clients and peers, thus indexed by both. + channel_numbers_by_client_and_peer: HashMap<(ClientSocket, PeerSocket), u16>, pending_commands: VecDeque, next_allocation_id: AllocationId, @@ -87,7 +89,7 @@ pub struct Server { pub enum Command { SendMessage { payload: Vec, - recipient: SocketAddr, + recipient: ClientSocket, }, /// Listen for traffic on the provided port [AddressFamily]. /// @@ -108,7 +110,7 @@ pub enum Command { ForwardData { id: AllocationId, data: Vec, - receiver: SocketAddr, + receiver: PeerSocket, }, /// At the latest, the [`Server`] needs to be woken at the specified deadline to execute time-based actions correctly. Wake { deadline: SystemTime }, @@ -187,7 +189,7 @@ where lowest_port, highest_port, channels_by_client_and_number: Default::default(), - channel_numbers_by_peer: Default::default(), + channel_numbers_by_client_and_peer: Default::default(), pending_commands: Default::default(), next_allocation_id: AllocationId(1), auth_secret: SecretString::from(hex::encode(rng.gen::<[u8; 32]>())), @@ -228,7 +230,7 @@ where /// /// After calling this method, you should call [`Server::next_command`] until it returns `None`. #[tracing::instrument(skip_all, fields(transaction_id, %sender, allocation, channel, recipient, peer), level = "error")] - pub fn handle_client_input(&mut self, bytes: &[u8], sender: SocketAddr, now: SystemTime) { + pub fn handle_client_input(&mut self, bytes: &[u8], sender: ClientSocket, now: SystemTime) { if tracing::enabled!(target: "wire", tracing::Level::TRACE) { let hex_bytes = hex::encode(bytes); tracing::trace!(target: "wire", %hex_bytes, "receiving bytes"); @@ -265,7 +267,7 @@ where pub fn handle_client_message( &mut self, message: ClientMessage, - sender: SocketAddr, + sender: ClientSocket, now: SystemTime, ) { let result = match message { @@ -294,7 +296,11 @@ where self.queue_error_response(sender, error_response) } - fn queue_error_response(&mut self, sender: SocketAddr, mut error_response: Message) { + fn queue_error_response( + &mut self, + sender: ClientSocket, + mut error_response: Message, + ) { let Some(error) = error_response.get_attribute::().cloned() else { debug_assert!(false, "Error response without an `ErrorCode`"); return; @@ -322,7 +328,7 @@ where pub fn handle_peer_traffic( &mut self, bytes: &[u8], - sender: SocketAddr, + sender: PeerSocket, allocation: AllocationId, ) { if tracing::enabled!(target: "wire", tracing::Level::TRACE) { @@ -330,14 +336,17 @@ where tracing::trace!(target: "wire", %hex_bytes, "receiving bytes"); } - let Some(recipient) = self.clients_by_allocation.get(&allocation) else { + let Some(client) = self.clients_by_allocation.get(&allocation).copied() else { tracing::debug!(target: "relay", "unknown allocation"); return; }; - Span::current().record("recipient", field::display(&recipient)); + Span::current().record("recipient", field::display(&client)); - let Some(channel_number) = self.channel_numbers_by_peer.get(&sender) else { + let Some(channel_number) = self + .channel_numbers_by_client_and_peer + .get(&(client, sender)) + else { tracing::debug!(target: "relay", "no active channel, refusing to relay {} bytes", bytes.len()); return; }; @@ -346,7 +355,7 @@ where let Some(channel) = self .channels_by_client_and_number - .get(&(*recipient, *channel_number)) + .get(&(client, *channel_number)) else { debug_assert!(false, "unknown channel {}", channel_number); return; @@ -376,7 +385,7 @@ where self.pending_commands.push_back(Command::SendMessage { payload: data, - recipient: *recipient, + recipient: client, }) } @@ -438,13 +447,13 @@ where self.pending_commands.pop_front() } - fn handle_binding_request(&mut self, message: Binding, sender: SocketAddr) { + fn handle_binding_request(&mut self, message: Binding, sender: ClientSocket) { let mut message = Message::new( MessageClass::SuccessResponse, BINDING, message.transaction_id(), ); - message.add_attribute(XorMappedAddress::new(sender)); + message.add_attribute(XorMappedAddress::new(sender.0)); self.send_message(message, sender); } @@ -455,7 +464,7 @@ where fn handle_allocate_request( &mut self, request: Allocate, - sender: SocketAddr, + sender: ClientSocket, now: SystemTime, ) -> Result<(), Message> { self.verify_auth(&request, now)?; @@ -518,7 +527,7 @@ where ))); } - message.add_attribute(XorMappedAddress::new(sender)); + message.add_attribute(XorMappedAddress::new(sender.0)); message.add_attribute(effective_lifetime.clone()); let wake_deadline = self.time_events.add( @@ -574,7 +583,7 @@ where fn handle_refresh_request( &mut self, request: Refresh, - sender: SocketAddr, + sender: ClientSocket, now: SystemTime, ) -> Result<(), Message> { self.verify_auth(&request, now)?; @@ -630,7 +639,7 @@ where fn handle_channel_bind_request( &mut self, request: ChannelBind, - sender: SocketAddr, + sender: ClientSocket, now: SystemTime, ) -> Result<(), Message> { self.verify_auth(&request, now)?; @@ -642,7 +651,7 @@ where // Note: `channel_number` is enforced to be in the correct range. let requested_channel = request.channel_number().value(); - let peer_address = request.xor_peer_address().address(); + let peer_address = PeerSocket(request.xor_peer_address().address()); Span::current().record("allocation", display(&allocation.id)); Span::current().record("peer", display(&peer_address)); @@ -656,7 +665,10 @@ where } // Ensure the same address isn't already bound to a different channel. - if let Some(number) = self.channel_numbers_by_peer.get(&peer_address) { + if let Some(number) = self + .channel_numbers_by_client_and_peer + .get(&(sender, peer_address)) + { if number != &requested_channel { tracing::warn!(target: "relay", existing_channel = %number, "Peer is already bound to another channel"); @@ -720,7 +732,7 @@ where fn handle_create_permission_request( &mut self, message: CreatePermission, - sender: SocketAddr, + sender: ClientSocket, now: SystemTime, ) -> Result<(), Message> { self.verify_auth(&message, now)?; @@ -736,7 +748,7 @@ where fn handle_channel_data_message( &mut self, message: ChannelData, - sender: SocketAddr, + sender: ClientSocket, _: SystemTime, ) { let channel_number = message.channel(); @@ -852,26 +864,31 @@ where fn create_channel_binding( &mut self, - client: SocketAddr, + client: ClientSocket, requested_channel: u16, - peer_address: SocketAddr, + peer: PeerSocket, id: AllocationId, now: SystemTime, ) { - self.channels_by_client_and_number.insert( + let existing = self.channels_by_client_and_number.insert( (client, requested_channel), Channel { expiry: now + CHANNEL_BINDING_DURATION, - peer_address, + peer_address: peer, allocation: id, bound: true, }, ); - self.channel_numbers_by_peer - .insert(peer_address, requested_channel); + debug_assert!(existing.is_none()); + + let existing = self + .channel_numbers_by_client_and_peer + .insert((client, peer), requested_channel); + + debug_assert!(existing.is_none()); } - fn send_message(&mut self, message: Message, recipient: SocketAddr) { + fn send_message(&mut self, message: Message, recipient: ClientSocket) { let method = message.method(); let class = message.class(); tracing::trace!(target: "relay", method = %message.method(), class = %message.class(), "Sending message"); @@ -950,14 +967,22 @@ where tracing::info!(target: "relay", %port, "Deleted allocation"); } - fn delete_channel_binding(&mut self, client: SocketAddr, chan: u16) { + fn delete_channel_binding(&mut self, client: ClientSocket, chan: u16) { let Some(channel) = self.channels_by_client_and_number.get(&(client, chan)) else { return; }; let addr = channel.peer_address; - self.channel_numbers_by_peer.remove(&addr); + let _peer_channel = self + .channel_numbers_by_client_and_peer + .remove(&(client, addr)); + debug_assert_eq!( + _peer_channel, + Some(chan), + "Channel state should be consistent" + ); + self.channels_by_client_and_number.remove(&(client, chan)); } } @@ -999,7 +1024,7 @@ struct Channel { expiry: SystemTime, /// The address of the peer that the channel is bound to. - peer_address: SocketAddr, + peer_address: PeerSocket, /// The allocation this channel belongs to. allocation: AllocationId, @@ -1031,8 +1056,8 @@ impl Allocation { /// /// This is called in the context of a channel binding with the requested peer address. /// We can only relay to the address if the allocation supports the same version of the IP protocol. - fn can_relay_to(&self, addr: SocketAddr) -> bool { - match addr { + fn can_relay_to(&self, addr: PeerSocket) -> bool { + match addr.0 { SocketAddr::V4(_) => self.first_relay_addr.is_ipv4(), // If we have an IPv4 address, it is in `first_relay_addr`, no need to check `second_relay_addr`. SocketAddr::V6(_) => { self.first_relay_addr.is_ipv6() @@ -1051,8 +1076,8 @@ impl Allocation { #[derive(PartialEq)] enum TimedAction { ExpireAllocation(AllocationId), - UnbindChannel((SocketAddr, u16)), - DeleteChannel((SocketAddr, u16)), + UnbindChannel((ClientSocket, u16)), + DeleteChannel((ClientSocket, u16)), } fn error_response( diff --git a/rust/relay/tests/regression.rs b/rust/relay/tests/regression.rs index 4aa2d58e5..fbb930114 100644 --- a/rust/relay/tests/regression.rs +++ b/rust/relay/tests/regression.rs @@ -1,7 +1,7 @@ use bytecodec::{DecodeExt, EncodeExt}; use firezone_relay::{ AddressFamily, Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, - ClientMessage, Command, IpStack, Refresh, Server, + ClientMessage, ClientSocket, Command, IpStack, PeerSocket, Refresh, Server, }; use rand::rngs::mock::StepRng; use secrecy::SecretString; @@ -686,8 +686,8 @@ fn parse_message(message: &[u8]) -> Message { } enum Input<'a> { - Client(SocketAddr, ClientMessage<'a>, SystemTime), - Peer(SocketAddr, Vec, u16), + Client(ClientSocket, ClientMessage<'a>, SystemTime), + Peer(PeerSocket, Vec, u16), Time(SystemTime), } @@ -696,11 +696,11 @@ fn from_client<'a>( message: impl Into>, now: SystemTime, ) -> Input<'a> { - Input::Client(from.into(), message.into(), now) + Input::Client(ClientSocket::new(from.into()), message.into(), now) } fn from_peer<'a>(from: impl Into, data: &[u8], port: u16) -> Input<'a> { - Input::Peer(from.into(), data.to_vec(), port) + Input::Peer(PeerSocket::new(from.into()), data.to_vec(), port) } fn forward_time_to<'a>(when: SystemTime) -> Input<'a> { @@ -709,22 +709,22 @@ fn forward_time_to<'a>(when: SystemTime) -> Input<'a> { #[derive(Debug)] enum Output<'a> { - SendMessage((SocketAddr, Message)), - SendChannelData((SocketAddr, ChannelData<'a>)), - Forward((SocketAddr, Vec, u16)), + SendMessage((ClientSocket, Message)), + SendChannelData((ClientSocket, ChannelData<'a>)), + Forward((PeerSocket, Vec, u16)), Wake(SystemTime), CreateAllocation(u16, AddressFamily), FreeAllocation(u16, AddressFamily), } fn send_message<'a>(source: impl Into, message: Message) -> Output<'a> { - Output::SendMessage((source.into(), message)) + Output::SendMessage((ClientSocket::new(source.into()), message)) } fn send_channel_data(source: impl Into, message: ChannelData) -> Output { - Output::SendChannelData((source.into(), message)) + Output::SendChannelData((ClientSocket::new(source.into()), message)) } fn forward(source: impl Into, data: &[u8], port: u16) -> Output { - Output::Forward((source.into(), data.to_vec(), port)) + Output::Forward((PeerSocket::new(source.into()), data.to_vec(), port)) }