fix(relay): correctly separate channel state for different peers (#3472)

Currently, there is a bug in the relay where the channel state of
different peers overlaps because the data isn't indexed correctly by
both peers and clients.

This PR fixes this, introduces more debug assertions (this bug was
caught by one) and also adds some new-type wrappers to avoid conflating
peers with clients.
This commit is contained in:
Thomas Eizinger
2024-02-01 12:53:54 +11:00
committed by GitHub
parent a5a6d81eb1
commit 84b3ac50ca
6 changed files with 166 additions and 92 deletions

View File

@@ -1,11 +1,10 @@
use crate::server::AllocationId;
use crate::udp_socket::UdpSocket;
use crate::AddressFamily;
use crate::{AddressFamily, PeerSocket};
use anyhow::{bail, Result};
use futures::channel::mpsc;
use futures::{SinkExt, StreamExt};
use std::convert::Infallible;
use std::net::SocketAddr;
use tokio::task;
/// The maximum amount of items that can be buffered in the channel to the allocation task.
@@ -18,12 +17,12 @@ pub struct Allocation {
///
/// Stored here to make resource-cleanup easy.
handle: task::JoinHandle<()>,
sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
sender: mpsc::Sender<(Vec<u8>, PeerSocket)>,
}
impl Allocation {
pub fn new(
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
relay_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
id: AllocationId,
family: AddressFamily,
port: u16,
@@ -62,7 +61,7 @@ impl Allocation {
///
/// All our data is relayed over UDP which by design is an unreliable protocol.
/// Thus, any application running on top of this relay must already account for potential packet loss.
pub fn send(&mut self, data: Vec<u8>, recipient: SocketAddr) -> Result<()> {
pub fn send(&mut self, data: Vec<u8>, recipient: PeerSocket) -> Result<()> {
match self.sender.try_send((data, recipient)) {
Ok(()) => Ok(()),
Err(e) if e.is_disconnected() => {
@@ -89,8 +88,8 @@ impl Drop for Allocation {
}
async fn forward_incoming_relay_data(
mut relayed_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
mut client_to_peer_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr)>,
mut relayed_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
mut client_to_peer_receiver: mpsc::Receiver<(Vec<u8>, PeerSocket)>,
id: AllocationId,
family: AddressFamily,
port: u16,
@@ -101,11 +100,11 @@ async fn forward_incoming_relay_data(
tokio::select! {
result = socket.recv() => {
let (data, sender) = result?;
relayed_data_sender.send((data.to_vec(), sender, id)).await?;
relayed_data_sender.send((data.to_vec(), PeerSocket::new(sender), id)).await?;
}
Some((data, recipient)) = client_to_peer_receiver.next() => {
socket.send_to(&data, recipient).await?;
socket.send_to(&data, recipient.into_socket()).await?;
}
}
}

View File

@@ -11,7 +11,7 @@ pub mod health_check;
pub mod proptest;
pub use allocation::Allocation;
pub use net_ext::{IpAddrExt, SocketAddrExt};
pub use net_ext::IpAddrExt;
pub use server::{
Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command,
CreatePermission, Refresh, Server,
@@ -22,7 +22,10 @@ pub use udp_socket::UdpSocket;
pub(crate) use time_events::TimeEvents;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::{
fmt,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
};
/// Describes the IP stack of a relay server.
#[derive(Debug, Copy, Clone)]
@@ -67,3 +70,65 @@ impl From<(Ipv4Addr, Ipv6Addr)> for IpStack {
IpStack::Dual { ip4, ip6 }
}
}
/// New-type for a client's socket.
///
/// From the [spec](https://www.rfc-editor.org/rfc/rfc8656#section-2-4.4):
///
/// > A STUN client that implements this specification.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct ClientSocket(SocketAddr);
impl ClientSocket {
pub fn new(addr: SocketAddr) -> Self {
Self(addr)
}
pub fn into_socket(self) -> SocketAddr {
self.0
}
pub fn family(&self) -> AddressFamily {
match self.0 {
SocketAddr::V4(_) => AddressFamily::V4,
SocketAddr::V6(_) => AddressFamily::V6,
}
}
}
impl fmt::Display for ClientSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
/// New-type for a peer's socket.
///
/// From the [spec](https://www.rfc-editor.org/rfc/rfc8656#section-2-4.8):
///
/// > A host with which the TURN client wishes to communicate. The TURN server relays traffic between the TURN client and its peer(s). The peer does not interact with the TURN server using the protocol defined in this document; rather, the peer receives data sent by the TURN server, and the peer sends data towards the TURN server.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct PeerSocket(SocketAddr);
impl PeerSocket {
pub fn new(addr: SocketAddr) -> Self {
Self(addr)
}
pub fn family(&self) -> AddressFamily {
match self.0 {
SocketAddr::V4(_) => AddressFamily::V4,
SocketAddr::V6(_) => AddressFamily::V6,
}
}
pub fn into_socket(self) -> SocketAddr {
self.0
}
}
impl fmt::Display for PeerSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

View File

@@ -2,8 +2,8 @@ use anyhow::{anyhow, bail, Context, Result};
use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use firezone_relay::{
AddressFamily, Allocation, AllocationId, Command, IpStack, Server, Sleep, SocketAddrExt,
UdpSocket,
AddressFamily, Allocation, AllocationId, ClientSocket, Command, IpStack, PeerSocket, Server,
Sleep, UdpSocket,
};
use futures::channel::mpsc;
use futures::{future, FutureExt, SinkExt, StreamExt};
@@ -313,14 +313,14 @@ fn make_rng(seed: Option<u64>) -> StdRng {
}
struct Eventloop<R> {
inbound_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr)>,
outbound_ip4_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
outbound_ip6_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
inbound_data_receiver: mpsc::Receiver<(Vec<u8>, ClientSocket)>,
outbound_ip4_data_sender: mpsc::Sender<(Vec<u8>, ClientSocket)>,
outbound_ip6_data_sender: mpsc::Sender<(Vec<u8>, ClientSocket)>,
server: Server<R>,
channel: Option<PhoenixChannel<JoinMessage, (), ()>>,
allocations: HashMap<(AllocationId, AddressFamily), Allocation>,
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
relay_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr, AllocationId)>,
relay_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
relay_data_receiver: mpsc::Receiver<(Vec<u8>, PeerSocket, AllocationId)>,
sleep: Sleep,
stats_log_interval: tokio::time::Interval,
@@ -338,10 +338,8 @@ where
) -> Result<Self> {
let (relay_data_sender, relay_data_receiver) = mpsc::channel(1);
let (inbound_data_sender, inbound_data_receiver) = mpsc::channel(1000);
let (outbound_ip4_data_sender, outbound_ip4_data_receiver) =
mpsc::channel::<(Vec<u8>, SocketAddr)>(1000);
let (outbound_ip6_data_sender, outbound_ip6_data_receiver) =
mpsc::channel::<(Vec<u8>, SocketAddr)>(1000);
let (outbound_ip4_data_sender, outbound_ip4_data_receiver) = mpsc::channel(1000);
let (outbound_ip6_data_sender, outbound_ip6_data_receiver) = mpsc::channel(1000);
if public_address.as_v4().is_some() {
tokio::spawn(main_udp_socket_task(
@@ -563,8 +561,8 @@ fn fmt_human_throughput(mut throughput: f64) -> String {
async fn main_udp_socket_task(
family: AddressFamily,
mut inbound_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
mut outbound_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr)>,
mut inbound_data_sender: mpsc::Sender<(Vec<u8>, ClientSocket)>,
mut outbound_data_receiver: mpsc::Receiver<(Vec<u8>, ClientSocket)>,
) -> Result<Infallible> {
let mut socket = UdpSocket::bind(family, 3478)?;
@@ -572,11 +570,11 @@ async fn main_udp_socket_task(
tokio::select! {
result = socket.recv() => {
let (data, sender) = result?;
inbound_data_sender.send((data.to_vec(), sender)).await?;
inbound_data_sender.send((data.to_vec(), ClientSocket::new(sender))).await?;
}
maybe_item = outbound_data_receiver.next() => {
let (data, recipient) = maybe_item.context("Outbound data channel closed")?;
socket.send_to(data.as_ref(), recipient).await?;
socket.send_to(data.as_ref(), recipient.into_socket()).await?;
}
}
}

View File

@@ -1,4 +1,4 @@
use std::net::{IpAddr, SocketAddr};
use std::net::IpAddr;
use stun_codec::rfc8656::attributes::AddressFamily;
pub trait IpAddrExt {
@@ -13,16 +13,3 @@ impl IpAddrExt for IpAddr {
}
}
}
pub trait SocketAddrExt {
fn family(&self) -> AddressFamily;
}
impl SocketAddrExt for SocketAddr {
fn family(&self) -> AddressFamily {
match self {
SocketAddr::V4(_) => AddressFamily::V4,
SocketAddr::V6(_) => AddressFamily::V6,
}
}
}

View File

@@ -8,7 +8,7 @@ pub use crate::server::client_message::{
use crate::auth::{MessageIntegrityExt, Nonces, FIREZONE};
use crate::net_ext::IpAddrExt;
use crate::{IpStack, TimeEvents};
use crate::{ClientSocket, IpStack, PeerSocket, TimeEvents};
use anyhow::Result;
use bytecodec::EncodeExt;
use core::fmt;
@@ -53,15 +53,17 @@ pub struct Server<R> {
public_address: IpStack,
/// All client allocations, indexed by client's socket address.
allocations: HashMap<SocketAddr, Allocation>,
clients_by_allocation: HashMap<AllocationId, SocketAddr>,
allocations: HashMap<ClientSocket, Allocation>,
clients_by_allocation: HashMap<AllocationId, ClientSocket>,
allocations_by_port: HashMap<u16, AllocationId>,
lowest_port: u16,
highest_port: u16,
channels_by_client_and_number: HashMap<(SocketAddr, u16), Channel>,
channel_numbers_by_peer: HashMap<SocketAddr, u16>,
/// Channel numbers are unique by client, thus indexed by both.
channels_by_client_and_number: HashMap<(ClientSocket, u16), Channel>,
/// Channel numbers are unique between clients and peers, thus indexed by both.
channel_numbers_by_client_and_peer: HashMap<(ClientSocket, PeerSocket), u16>,
pending_commands: VecDeque<Command>,
next_allocation_id: AllocationId,
@@ -87,7 +89,7 @@ pub struct Server<R> {
pub enum Command {
SendMessage {
payload: Vec<u8>,
recipient: SocketAddr,
recipient: ClientSocket,
},
/// Listen for traffic on the provided port [AddressFamily].
///
@@ -108,7 +110,7 @@ pub enum Command {
ForwardData {
id: AllocationId,
data: Vec<u8>,
receiver: SocketAddr,
receiver: PeerSocket,
},
/// At the latest, the [`Server`] needs to be woken at the specified deadline to execute time-based actions correctly.
Wake { deadline: SystemTime },
@@ -187,7 +189,7 @@ where
lowest_port,
highest_port,
channels_by_client_and_number: Default::default(),
channel_numbers_by_peer: Default::default(),
channel_numbers_by_client_and_peer: Default::default(),
pending_commands: Default::default(),
next_allocation_id: AllocationId(1),
auth_secret: SecretString::from(hex::encode(rng.gen::<[u8; 32]>())),
@@ -228,7 +230,7 @@ where
///
/// After calling this method, you should call [`Server::next_command`] until it returns `None`.
#[tracing::instrument(skip_all, fields(transaction_id, %sender, allocation, channel, recipient, peer), level = "error")]
pub fn handle_client_input(&mut self, bytes: &[u8], sender: SocketAddr, now: SystemTime) {
pub fn handle_client_input(&mut self, bytes: &[u8], sender: ClientSocket, now: SystemTime) {
if tracing::enabled!(target: "wire", tracing::Level::TRACE) {
let hex_bytes = hex::encode(bytes);
tracing::trace!(target: "wire", %hex_bytes, "receiving bytes");
@@ -265,7 +267,7 @@ where
pub fn handle_client_message(
&mut self,
message: ClientMessage,
sender: SocketAddr,
sender: ClientSocket,
now: SystemTime,
) {
let result = match message {
@@ -294,7 +296,11 @@ where
self.queue_error_response(sender, error_response)
}
fn queue_error_response(&mut self, sender: SocketAddr, mut error_response: Message<Attribute>) {
fn queue_error_response(
&mut self,
sender: ClientSocket,
mut error_response: Message<Attribute>,
) {
let Some(error) = error_response.get_attribute::<ErrorCode>().cloned() else {
debug_assert!(false, "Error response without an `ErrorCode`");
return;
@@ -322,7 +328,7 @@ where
pub fn handle_peer_traffic(
&mut self,
bytes: &[u8],
sender: SocketAddr,
sender: PeerSocket,
allocation: AllocationId,
) {
if tracing::enabled!(target: "wire", tracing::Level::TRACE) {
@@ -330,14 +336,17 @@ where
tracing::trace!(target: "wire", %hex_bytes, "receiving bytes");
}
let Some(recipient) = self.clients_by_allocation.get(&allocation) else {
let Some(client) = self.clients_by_allocation.get(&allocation).copied() else {
tracing::debug!(target: "relay", "unknown allocation");
return;
};
Span::current().record("recipient", field::display(&recipient));
Span::current().record("recipient", field::display(&client));
let Some(channel_number) = self.channel_numbers_by_peer.get(&sender) else {
let Some(channel_number) = self
.channel_numbers_by_client_and_peer
.get(&(client, sender))
else {
tracing::debug!(target: "relay", "no active channel, refusing to relay {} bytes", bytes.len());
return;
};
@@ -346,7 +355,7 @@ where
let Some(channel) = self
.channels_by_client_and_number
.get(&(*recipient, *channel_number))
.get(&(client, *channel_number))
else {
debug_assert!(false, "unknown channel {}", channel_number);
return;
@@ -376,7 +385,7 @@ where
self.pending_commands.push_back(Command::SendMessage {
payload: data,
recipient: *recipient,
recipient: client,
})
}
@@ -438,13 +447,13 @@ where
self.pending_commands.pop_front()
}
fn handle_binding_request(&mut self, message: Binding, sender: SocketAddr) {
fn handle_binding_request(&mut self, message: Binding, sender: ClientSocket) {
let mut message = Message::new(
MessageClass::SuccessResponse,
BINDING,
message.transaction_id(),
);
message.add_attribute(XorMappedAddress::new(sender));
message.add_attribute(XorMappedAddress::new(sender.0));
self.send_message(message, sender);
}
@@ -455,7 +464,7 @@ where
fn handle_allocate_request(
&mut self,
request: Allocate,
sender: SocketAddr,
sender: ClientSocket,
now: SystemTime,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request, now)?;
@@ -518,7 +527,7 @@ where
)));
}
message.add_attribute(XorMappedAddress::new(sender));
message.add_attribute(XorMappedAddress::new(sender.0));
message.add_attribute(effective_lifetime.clone());
let wake_deadline = self.time_events.add(
@@ -574,7 +583,7 @@ where
fn handle_refresh_request(
&mut self,
request: Refresh,
sender: SocketAddr,
sender: ClientSocket,
now: SystemTime,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request, now)?;
@@ -630,7 +639,7 @@ where
fn handle_channel_bind_request(
&mut self,
request: ChannelBind,
sender: SocketAddr,
sender: ClientSocket,
now: SystemTime,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request, now)?;
@@ -642,7 +651,7 @@ where
// Note: `channel_number` is enforced to be in the correct range.
let requested_channel = request.channel_number().value();
let peer_address = request.xor_peer_address().address();
let peer_address = PeerSocket(request.xor_peer_address().address());
Span::current().record("allocation", display(&allocation.id));
Span::current().record("peer", display(&peer_address));
@@ -656,7 +665,10 @@ where
}
// Ensure the same address isn't already bound to a different channel.
if let Some(number) = self.channel_numbers_by_peer.get(&peer_address) {
if let Some(number) = self
.channel_numbers_by_client_and_peer
.get(&(sender, peer_address))
{
if number != &requested_channel {
tracing::warn!(target: "relay", existing_channel = %number, "Peer is already bound to another channel");
@@ -720,7 +732,7 @@ where
fn handle_create_permission_request(
&mut self,
message: CreatePermission,
sender: SocketAddr,
sender: ClientSocket,
now: SystemTime,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&message, now)?;
@@ -736,7 +748,7 @@ where
fn handle_channel_data_message(
&mut self,
message: ChannelData,
sender: SocketAddr,
sender: ClientSocket,
_: SystemTime,
) {
let channel_number = message.channel();
@@ -852,26 +864,31 @@ where
fn create_channel_binding(
&mut self,
client: SocketAddr,
client: ClientSocket,
requested_channel: u16,
peer_address: SocketAddr,
peer: PeerSocket,
id: AllocationId,
now: SystemTime,
) {
self.channels_by_client_and_number.insert(
let existing = self.channels_by_client_and_number.insert(
(client, requested_channel),
Channel {
expiry: now + CHANNEL_BINDING_DURATION,
peer_address,
peer_address: peer,
allocation: id,
bound: true,
},
);
self.channel_numbers_by_peer
.insert(peer_address, requested_channel);
debug_assert!(existing.is_none());
let existing = self
.channel_numbers_by_client_and_peer
.insert((client, peer), requested_channel);
debug_assert!(existing.is_none());
}
fn send_message(&mut self, message: Message<Attribute>, recipient: SocketAddr) {
fn send_message(&mut self, message: Message<Attribute>, recipient: ClientSocket) {
let method = message.method();
let class = message.class();
tracing::trace!(target: "relay", method = %message.method(), class = %message.class(), "Sending message");
@@ -950,14 +967,22 @@ where
tracing::info!(target: "relay", %port, "Deleted allocation");
}
fn delete_channel_binding(&mut self, client: SocketAddr, chan: u16) {
fn delete_channel_binding(&mut self, client: ClientSocket, chan: u16) {
let Some(channel) = self.channels_by_client_and_number.get(&(client, chan)) else {
return;
};
let addr = channel.peer_address;
self.channel_numbers_by_peer.remove(&addr);
let _peer_channel = self
.channel_numbers_by_client_and_peer
.remove(&(client, addr));
debug_assert_eq!(
_peer_channel,
Some(chan),
"Channel state should be consistent"
);
self.channels_by_client_and_number.remove(&(client, chan));
}
}
@@ -999,7 +1024,7 @@ struct Channel {
expiry: SystemTime,
/// The address of the peer that the channel is bound to.
peer_address: SocketAddr,
peer_address: PeerSocket,
/// The allocation this channel belongs to.
allocation: AllocationId,
@@ -1031,8 +1056,8 @@ impl Allocation {
///
/// This is called in the context of a channel binding with the requested peer address.
/// We can only relay to the address if the allocation supports the same version of the IP protocol.
fn can_relay_to(&self, addr: SocketAddr) -> bool {
match addr {
fn can_relay_to(&self, addr: PeerSocket) -> bool {
match addr.0 {
SocketAddr::V4(_) => self.first_relay_addr.is_ipv4(), // If we have an IPv4 address, it is in `first_relay_addr`, no need to check `second_relay_addr`.
SocketAddr::V6(_) => {
self.first_relay_addr.is_ipv6()
@@ -1051,8 +1076,8 @@ impl Allocation {
#[derive(PartialEq)]
enum TimedAction {
ExpireAllocation(AllocationId),
UnbindChannel((SocketAddr, u16)),
DeleteChannel((SocketAddr, u16)),
UnbindChannel((ClientSocket, u16)),
DeleteChannel((ClientSocket, u16)),
}
fn error_response(

View File

@@ -1,7 +1,7 @@
use bytecodec::{DecodeExt, EncodeExt};
use firezone_relay::{
AddressFamily, Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData,
ClientMessage, Command, IpStack, Refresh, Server,
ClientMessage, ClientSocket, Command, IpStack, PeerSocket, Refresh, Server,
};
use rand::rngs::mock::StepRng;
use secrecy::SecretString;
@@ -686,8 +686,8 @@ fn parse_message(message: &[u8]) -> Message<Attribute> {
}
enum Input<'a> {
Client(SocketAddr, ClientMessage<'a>, SystemTime),
Peer(SocketAddr, Vec<u8>, u16),
Client(ClientSocket, ClientMessage<'a>, SystemTime),
Peer(PeerSocket, Vec<u8>, u16),
Time(SystemTime),
}
@@ -696,11 +696,11 @@ fn from_client<'a>(
message: impl Into<ClientMessage<'a>>,
now: SystemTime,
) -> Input<'a> {
Input::Client(from.into(), message.into(), now)
Input::Client(ClientSocket::new(from.into()), message.into(), now)
}
fn from_peer<'a>(from: impl Into<SocketAddr>, data: &[u8], port: u16) -> Input<'a> {
Input::Peer(from.into(), data.to_vec(), port)
Input::Peer(PeerSocket::new(from.into()), data.to_vec(), port)
}
fn forward_time_to<'a>(when: SystemTime) -> Input<'a> {
@@ -709,22 +709,22 @@ fn forward_time_to<'a>(when: SystemTime) -> Input<'a> {
#[derive(Debug)]
enum Output<'a> {
SendMessage((SocketAddr, Message<Attribute>)),
SendChannelData((SocketAddr, ChannelData<'a>)),
Forward((SocketAddr, Vec<u8>, u16)),
SendMessage((ClientSocket, Message<Attribute>)),
SendChannelData((ClientSocket, ChannelData<'a>)),
Forward((PeerSocket, Vec<u8>, u16)),
Wake(SystemTime),
CreateAllocation(u16, AddressFamily),
FreeAllocation(u16, AddressFamily),
}
fn send_message<'a>(source: impl Into<SocketAddr>, message: Message<Attribute>) -> Output<'a> {
Output::SendMessage((source.into(), message))
Output::SendMessage((ClientSocket::new(source.into()), message))
}
fn send_channel_data(source: impl Into<SocketAddr>, message: ChannelData) -> Output {
Output::SendChannelData((source.into(), message))
Output::SendChannelData((ClientSocket::new(source.into()), message))
}
fn forward(source: impl Into<SocketAddr>, data: &[u8], port: u16) -> Output {
Output::Forward((source.into(), data.to_vec(), port))
Output::Forward((PeerSocket::new(source.into()), data.to_vec(), port))
}