mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
fix(relay): correctly separate channel state for different peers (#3472)
Currently, there is a bug in the relay where the channel state of different peers overlaps because the data isn't indexed correctly by both peers and clients. This PR fixes this, introduces more debug assertions (this bug was caught by one) and also adds some new-type wrappers to avoid conflating peers with clients.
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
use crate::server::AllocationId;
|
||||
use crate::udp_socket::UdpSocket;
|
||||
use crate::AddressFamily;
|
||||
use crate::{AddressFamily, PeerSocket};
|
||||
use anyhow::{bail, Result};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::task;
|
||||
|
||||
/// The maximum amount of items that can be buffered in the channel to the allocation task.
|
||||
@@ -18,12 +17,12 @@ pub struct Allocation {
|
||||
///
|
||||
/// Stored here to make resource-cleanup easy.
|
||||
handle: task::JoinHandle<()>,
|
||||
sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
|
||||
sender: mpsc::Sender<(Vec<u8>, PeerSocket)>,
|
||||
}
|
||||
|
||||
impl Allocation {
|
||||
pub fn new(
|
||||
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
relay_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
|
||||
id: AllocationId,
|
||||
family: AddressFamily,
|
||||
port: u16,
|
||||
@@ -62,7 +61,7 @@ impl Allocation {
|
||||
///
|
||||
/// All our data is relayed over UDP which by design is an unreliable protocol.
|
||||
/// Thus, any application running on top of this relay must already account for potential packet loss.
|
||||
pub fn send(&mut self, data: Vec<u8>, recipient: SocketAddr) -> Result<()> {
|
||||
pub fn send(&mut self, data: Vec<u8>, recipient: PeerSocket) -> Result<()> {
|
||||
match self.sender.try_send((data, recipient)) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(e) if e.is_disconnected() => {
|
||||
@@ -89,8 +88,8 @@ impl Drop for Allocation {
|
||||
}
|
||||
|
||||
async fn forward_incoming_relay_data(
|
||||
mut relayed_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
mut client_to_peer_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr)>,
|
||||
mut relayed_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
|
||||
mut client_to_peer_receiver: mpsc::Receiver<(Vec<u8>, PeerSocket)>,
|
||||
id: AllocationId,
|
||||
family: AddressFamily,
|
||||
port: u16,
|
||||
@@ -101,11 +100,11 @@ async fn forward_incoming_relay_data(
|
||||
tokio::select! {
|
||||
result = socket.recv() => {
|
||||
let (data, sender) = result?;
|
||||
relayed_data_sender.send((data.to_vec(), sender, id)).await?;
|
||||
relayed_data_sender.send((data.to_vec(), PeerSocket::new(sender), id)).await?;
|
||||
}
|
||||
|
||||
Some((data, recipient)) = client_to_peer_receiver.next() => {
|
||||
socket.send_to(&data, recipient).await?;
|
||||
socket.send_to(&data, recipient.into_socket()).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ pub mod health_check;
|
||||
pub mod proptest;
|
||||
|
||||
pub use allocation::Allocation;
|
||||
pub use net_ext::{IpAddrExt, SocketAddrExt};
|
||||
pub use net_ext::IpAddrExt;
|
||||
pub use server::{
|
||||
Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData, ClientMessage, Command,
|
||||
CreatePermission, Refresh, Server,
|
||||
@@ -22,7 +22,10 @@ pub use udp_socket::UdpSocket;
|
||||
|
||||
pub(crate) use time_events::TimeEvents;
|
||||
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::{
|
||||
fmt,
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
};
|
||||
|
||||
/// Describes the IP stack of a relay server.
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@@ -67,3 +70,65 @@ impl From<(Ipv4Addr, Ipv6Addr)> for IpStack {
|
||||
IpStack::Dual { ip4, ip6 }
|
||||
}
|
||||
}
|
||||
|
||||
/// New-type for a client's socket.
|
||||
///
|
||||
/// From the [spec](https://www.rfc-editor.org/rfc/rfc8656#section-2-4.4):
|
||||
///
|
||||
/// > A STUN client that implements this specification.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub struct ClientSocket(SocketAddr);
|
||||
|
||||
impl ClientSocket {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
Self(addr)
|
||||
}
|
||||
|
||||
pub fn into_socket(self) -> SocketAddr {
|
||||
self.0
|
||||
}
|
||||
|
||||
pub fn family(&self) -> AddressFamily {
|
||||
match self.0 {
|
||||
SocketAddr::V4(_) => AddressFamily::V4,
|
||||
SocketAddr::V6(_) => AddressFamily::V6,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ClientSocket {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
/// New-type for a peer's socket.
|
||||
///
|
||||
/// From the [spec](https://www.rfc-editor.org/rfc/rfc8656#section-2-4.8):
|
||||
///
|
||||
/// > A host with which the TURN client wishes to communicate. The TURN server relays traffic between the TURN client and its peer(s). The peer does not interact with the TURN server using the protocol defined in this document; rather, the peer receives data sent by the TURN server, and the peer sends data towards the TURN server.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub struct PeerSocket(SocketAddr);
|
||||
|
||||
impl PeerSocket {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
Self(addr)
|
||||
}
|
||||
|
||||
pub fn family(&self) -> AddressFamily {
|
||||
match self.0 {
|
||||
SocketAddr::V4(_) => AddressFamily::V4,
|
||||
SocketAddr::V6(_) => AddressFamily::V6,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_socket(self) -> SocketAddr {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PeerSocket {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ use anyhow::{anyhow, bail, Context, Result};
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use clap::Parser;
|
||||
use firezone_relay::{
|
||||
AddressFamily, Allocation, AllocationId, Command, IpStack, Server, Sleep, SocketAddrExt,
|
||||
UdpSocket,
|
||||
AddressFamily, Allocation, AllocationId, ClientSocket, Command, IpStack, PeerSocket, Server,
|
||||
Sleep, UdpSocket,
|
||||
};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{future, FutureExt, SinkExt, StreamExt};
|
||||
@@ -313,14 +313,14 @@ fn make_rng(seed: Option<u64>) -> StdRng {
|
||||
}
|
||||
|
||||
struct Eventloop<R> {
|
||||
inbound_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr)>,
|
||||
outbound_ip4_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
|
||||
outbound_ip6_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
|
||||
inbound_data_receiver: mpsc::Receiver<(Vec<u8>, ClientSocket)>,
|
||||
outbound_ip4_data_sender: mpsc::Sender<(Vec<u8>, ClientSocket)>,
|
||||
outbound_ip6_data_sender: mpsc::Sender<(Vec<u8>, ClientSocket)>,
|
||||
server: Server<R>,
|
||||
channel: Option<PhoenixChannel<JoinMessage, (), ()>>,
|
||||
allocations: HashMap<(AllocationId, AddressFamily), Allocation>,
|
||||
relay_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
relay_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr, AllocationId)>,
|
||||
relay_data_sender: mpsc::Sender<(Vec<u8>, PeerSocket, AllocationId)>,
|
||||
relay_data_receiver: mpsc::Receiver<(Vec<u8>, PeerSocket, AllocationId)>,
|
||||
sleep: Sleep,
|
||||
|
||||
stats_log_interval: tokio::time::Interval,
|
||||
@@ -338,10 +338,8 @@ where
|
||||
) -> Result<Self> {
|
||||
let (relay_data_sender, relay_data_receiver) = mpsc::channel(1);
|
||||
let (inbound_data_sender, inbound_data_receiver) = mpsc::channel(1000);
|
||||
let (outbound_ip4_data_sender, outbound_ip4_data_receiver) =
|
||||
mpsc::channel::<(Vec<u8>, SocketAddr)>(1000);
|
||||
let (outbound_ip6_data_sender, outbound_ip6_data_receiver) =
|
||||
mpsc::channel::<(Vec<u8>, SocketAddr)>(1000);
|
||||
let (outbound_ip4_data_sender, outbound_ip4_data_receiver) = mpsc::channel(1000);
|
||||
let (outbound_ip6_data_sender, outbound_ip6_data_receiver) = mpsc::channel(1000);
|
||||
|
||||
if public_address.as_v4().is_some() {
|
||||
tokio::spawn(main_udp_socket_task(
|
||||
@@ -563,8 +561,8 @@ fn fmt_human_throughput(mut throughput: f64) -> String {
|
||||
|
||||
async fn main_udp_socket_task(
|
||||
family: AddressFamily,
|
||||
mut inbound_data_sender: mpsc::Sender<(Vec<u8>, SocketAddr)>,
|
||||
mut outbound_data_receiver: mpsc::Receiver<(Vec<u8>, SocketAddr)>,
|
||||
mut inbound_data_sender: mpsc::Sender<(Vec<u8>, ClientSocket)>,
|
||||
mut outbound_data_receiver: mpsc::Receiver<(Vec<u8>, ClientSocket)>,
|
||||
) -> Result<Infallible> {
|
||||
let mut socket = UdpSocket::bind(family, 3478)?;
|
||||
|
||||
@@ -572,11 +570,11 @@ async fn main_udp_socket_task(
|
||||
tokio::select! {
|
||||
result = socket.recv() => {
|
||||
let (data, sender) = result?;
|
||||
inbound_data_sender.send((data.to_vec(), sender)).await?;
|
||||
inbound_data_sender.send((data.to_vec(), ClientSocket::new(sender))).await?;
|
||||
}
|
||||
maybe_item = outbound_data_receiver.next() => {
|
||||
let (data, recipient) = maybe_item.context("Outbound data channel closed")?;
|
||||
socket.send_to(data.as_ref(), recipient).await?;
|
||||
socket.send_to(data.as_ref(), recipient.into_socket()).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::net::IpAddr;
|
||||
use stun_codec::rfc8656::attributes::AddressFamily;
|
||||
|
||||
pub trait IpAddrExt {
|
||||
@@ -13,16 +13,3 @@ impl IpAddrExt for IpAddr {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SocketAddrExt {
|
||||
fn family(&self) -> AddressFamily;
|
||||
}
|
||||
|
||||
impl SocketAddrExt for SocketAddr {
|
||||
fn family(&self) -> AddressFamily {
|
||||
match self {
|
||||
SocketAddr::V4(_) => AddressFamily::V4,
|
||||
SocketAddr::V6(_) => AddressFamily::V6,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ pub use crate::server::client_message::{
|
||||
|
||||
use crate::auth::{MessageIntegrityExt, Nonces, FIREZONE};
|
||||
use crate::net_ext::IpAddrExt;
|
||||
use crate::{IpStack, TimeEvents};
|
||||
use crate::{ClientSocket, IpStack, PeerSocket, TimeEvents};
|
||||
use anyhow::Result;
|
||||
use bytecodec::EncodeExt;
|
||||
use core::fmt;
|
||||
@@ -53,15 +53,17 @@ pub struct Server<R> {
|
||||
public_address: IpStack,
|
||||
|
||||
/// All client allocations, indexed by client's socket address.
|
||||
allocations: HashMap<SocketAddr, Allocation>,
|
||||
clients_by_allocation: HashMap<AllocationId, SocketAddr>,
|
||||
allocations: HashMap<ClientSocket, Allocation>,
|
||||
clients_by_allocation: HashMap<AllocationId, ClientSocket>,
|
||||
allocations_by_port: HashMap<u16, AllocationId>,
|
||||
|
||||
lowest_port: u16,
|
||||
highest_port: u16,
|
||||
|
||||
channels_by_client_and_number: HashMap<(SocketAddr, u16), Channel>,
|
||||
channel_numbers_by_peer: HashMap<SocketAddr, u16>,
|
||||
/// Channel numbers are unique by client, thus indexed by both.
|
||||
channels_by_client_and_number: HashMap<(ClientSocket, u16), Channel>,
|
||||
/// Channel numbers are unique between clients and peers, thus indexed by both.
|
||||
channel_numbers_by_client_and_peer: HashMap<(ClientSocket, PeerSocket), u16>,
|
||||
|
||||
pending_commands: VecDeque<Command>,
|
||||
next_allocation_id: AllocationId,
|
||||
@@ -87,7 +89,7 @@ pub struct Server<R> {
|
||||
pub enum Command {
|
||||
SendMessage {
|
||||
payload: Vec<u8>,
|
||||
recipient: SocketAddr,
|
||||
recipient: ClientSocket,
|
||||
},
|
||||
/// Listen for traffic on the provided port [AddressFamily].
|
||||
///
|
||||
@@ -108,7 +110,7 @@ pub enum Command {
|
||||
ForwardData {
|
||||
id: AllocationId,
|
||||
data: Vec<u8>,
|
||||
receiver: SocketAddr,
|
||||
receiver: PeerSocket,
|
||||
},
|
||||
/// At the latest, the [`Server`] needs to be woken at the specified deadline to execute time-based actions correctly.
|
||||
Wake { deadline: SystemTime },
|
||||
@@ -187,7 +189,7 @@ where
|
||||
lowest_port,
|
||||
highest_port,
|
||||
channels_by_client_and_number: Default::default(),
|
||||
channel_numbers_by_peer: Default::default(),
|
||||
channel_numbers_by_client_and_peer: Default::default(),
|
||||
pending_commands: Default::default(),
|
||||
next_allocation_id: AllocationId(1),
|
||||
auth_secret: SecretString::from(hex::encode(rng.gen::<[u8; 32]>())),
|
||||
@@ -228,7 +230,7 @@ where
|
||||
///
|
||||
/// After calling this method, you should call [`Server::next_command`] until it returns `None`.
|
||||
#[tracing::instrument(skip_all, fields(transaction_id, %sender, allocation, channel, recipient, peer), level = "error")]
|
||||
pub fn handle_client_input(&mut self, bytes: &[u8], sender: SocketAddr, now: SystemTime) {
|
||||
pub fn handle_client_input(&mut self, bytes: &[u8], sender: ClientSocket, now: SystemTime) {
|
||||
if tracing::enabled!(target: "wire", tracing::Level::TRACE) {
|
||||
let hex_bytes = hex::encode(bytes);
|
||||
tracing::trace!(target: "wire", %hex_bytes, "receiving bytes");
|
||||
@@ -265,7 +267,7 @@ where
|
||||
pub fn handle_client_message(
|
||||
&mut self,
|
||||
message: ClientMessage,
|
||||
sender: SocketAddr,
|
||||
sender: ClientSocket,
|
||||
now: SystemTime,
|
||||
) {
|
||||
let result = match message {
|
||||
@@ -294,7 +296,11 @@ where
|
||||
self.queue_error_response(sender, error_response)
|
||||
}
|
||||
|
||||
fn queue_error_response(&mut self, sender: SocketAddr, mut error_response: Message<Attribute>) {
|
||||
fn queue_error_response(
|
||||
&mut self,
|
||||
sender: ClientSocket,
|
||||
mut error_response: Message<Attribute>,
|
||||
) {
|
||||
let Some(error) = error_response.get_attribute::<ErrorCode>().cloned() else {
|
||||
debug_assert!(false, "Error response without an `ErrorCode`");
|
||||
return;
|
||||
@@ -322,7 +328,7 @@ where
|
||||
pub fn handle_peer_traffic(
|
||||
&mut self,
|
||||
bytes: &[u8],
|
||||
sender: SocketAddr,
|
||||
sender: PeerSocket,
|
||||
allocation: AllocationId,
|
||||
) {
|
||||
if tracing::enabled!(target: "wire", tracing::Level::TRACE) {
|
||||
@@ -330,14 +336,17 @@ where
|
||||
tracing::trace!(target: "wire", %hex_bytes, "receiving bytes");
|
||||
}
|
||||
|
||||
let Some(recipient) = self.clients_by_allocation.get(&allocation) else {
|
||||
let Some(client) = self.clients_by_allocation.get(&allocation).copied() else {
|
||||
tracing::debug!(target: "relay", "unknown allocation");
|
||||
return;
|
||||
};
|
||||
|
||||
Span::current().record("recipient", field::display(&recipient));
|
||||
Span::current().record("recipient", field::display(&client));
|
||||
|
||||
let Some(channel_number) = self.channel_numbers_by_peer.get(&sender) else {
|
||||
let Some(channel_number) = self
|
||||
.channel_numbers_by_client_and_peer
|
||||
.get(&(client, sender))
|
||||
else {
|
||||
tracing::debug!(target: "relay", "no active channel, refusing to relay {} bytes", bytes.len());
|
||||
return;
|
||||
};
|
||||
@@ -346,7 +355,7 @@ where
|
||||
|
||||
let Some(channel) = self
|
||||
.channels_by_client_and_number
|
||||
.get(&(*recipient, *channel_number))
|
||||
.get(&(client, *channel_number))
|
||||
else {
|
||||
debug_assert!(false, "unknown channel {}", channel_number);
|
||||
return;
|
||||
@@ -376,7 +385,7 @@ where
|
||||
|
||||
self.pending_commands.push_back(Command::SendMessage {
|
||||
payload: data,
|
||||
recipient: *recipient,
|
||||
recipient: client,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -438,13 +447,13 @@ where
|
||||
self.pending_commands.pop_front()
|
||||
}
|
||||
|
||||
fn handle_binding_request(&mut self, message: Binding, sender: SocketAddr) {
|
||||
fn handle_binding_request(&mut self, message: Binding, sender: ClientSocket) {
|
||||
let mut message = Message::new(
|
||||
MessageClass::SuccessResponse,
|
||||
BINDING,
|
||||
message.transaction_id(),
|
||||
);
|
||||
message.add_attribute(XorMappedAddress::new(sender));
|
||||
message.add_attribute(XorMappedAddress::new(sender.0));
|
||||
|
||||
self.send_message(message, sender);
|
||||
}
|
||||
@@ -455,7 +464,7 @@ where
|
||||
fn handle_allocate_request(
|
||||
&mut self,
|
||||
request: Allocate,
|
||||
sender: SocketAddr,
|
||||
sender: ClientSocket,
|
||||
now: SystemTime,
|
||||
) -> Result<(), Message<Attribute>> {
|
||||
self.verify_auth(&request, now)?;
|
||||
@@ -518,7 +527,7 @@ where
|
||||
)));
|
||||
}
|
||||
|
||||
message.add_attribute(XorMappedAddress::new(sender));
|
||||
message.add_attribute(XorMappedAddress::new(sender.0));
|
||||
message.add_attribute(effective_lifetime.clone());
|
||||
|
||||
let wake_deadline = self.time_events.add(
|
||||
@@ -574,7 +583,7 @@ where
|
||||
fn handle_refresh_request(
|
||||
&mut self,
|
||||
request: Refresh,
|
||||
sender: SocketAddr,
|
||||
sender: ClientSocket,
|
||||
now: SystemTime,
|
||||
) -> Result<(), Message<Attribute>> {
|
||||
self.verify_auth(&request, now)?;
|
||||
@@ -630,7 +639,7 @@ where
|
||||
fn handle_channel_bind_request(
|
||||
&mut self,
|
||||
request: ChannelBind,
|
||||
sender: SocketAddr,
|
||||
sender: ClientSocket,
|
||||
now: SystemTime,
|
||||
) -> Result<(), Message<Attribute>> {
|
||||
self.verify_auth(&request, now)?;
|
||||
@@ -642,7 +651,7 @@ where
|
||||
|
||||
// Note: `channel_number` is enforced to be in the correct range.
|
||||
let requested_channel = request.channel_number().value();
|
||||
let peer_address = request.xor_peer_address().address();
|
||||
let peer_address = PeerSocket(request.xor_peer_address().address());
|
||||
|
||||
Span::current().record("allocation", display(&allocation.id));
|
||||
Span::current().record("peer", display(&peer_address));
|
||||
@@ -656,7 +665,10 @@ where
|
||||
}
|
||||
|
||||
// Ensure the same address isn't already bound to a different channel.
|
||||
if let Some(number) = self.channel_numbers_by_peer.get(&peer_address) {
|
||||
if let Some(number) = self
|
||||
.channel_numbers_by_client_and_peer
|
||||
.get(&(sender, peer_address))
|
||||
{
|
||||
if number != &requested_channel {
|
||||
tracing::warn!(target: "relay", existing_channel = %number, "Peer is already bound to another channel");
|
||||
|
||||
@@ -720,7 +732,7 @@ where
|
||||
fn handle_create_permission_request(
|
||||
&mut self,
|
||||
message: CreatePermission,
|
||||
sender: SocketAddr,
|
||||
sender: ClientSocket,
|
||||
now: SystemTime,
|
||||
) -> Result<(), Message<Attribute>> {
|
||||
self.verify_auth(&message, now)?;
|
||||
@@ -736,7 +748,7 @@ where
|
||||
fn handle_channel_data_message(
|
||||
&mut self,
|
||||
message: ChannelData,
|
||||
sender: SocketAddr,
|
||||
sender: ClientSocket,
|
||||
_: SystemTime,
|
||||
) {
|
||||
let channel_number = message.channel();
|
||||
@@ -852,26 +864,31 @@ where
|
||||
|
||||
fn create_channel_binding(
|
||||
&mut self,
|
||||
client: SocketAddr,
|
||||
client: ClientSocket,
|
||||
requested_channel: u16,
|
||||
peer_address: SocketAddr,
|
||||
peer: PeerSocket,
|
||||
id: AllocationId,
|
||||
now: SystemTime,
|
||||
) {
|
||||
self.channels_by_client_and_number.insert(
|
||||
let existing = self.channels_by_client_and_number.insert(
|
||||
(client, requested_channel),
|
||||
Channel {
|
||||
expiry: now + CHANNEL_BINDING_DURATION,
|
||||
peer_address,
|
||||
peer_address: peer,
|
||||
allocation: id,
|
||||
bound: true,
|
||||
},
|
||||
);
|
||||
self.channel_numbers_by_peer
|
||||
.insert(peer_address, requested_channel);
|
||||
debug_assert!(existing.is_none());
|
||||
|
||||
let existing = self
|
||||
.channel_numbers_by_client_and_peer
|
||||
.insert((client, peer), requested_channel);
|
||||
|
||||
debug_assert!(existing.is_none());
|
||||
}
|
||||
|
||||
fn send_message(&mut self, message: Message<Attribute>, recipient: SocketAddr) {
|
||||
fn send_message(&mut self, message: Message<Attribute>, recipient: ClientSocket) {
|
||||
let method = message.method();
|
||||
let class = message.class();
|
||||
tracing::trace!(target: "relay", method = %message.method(), class = %message.class(), "Sending message");
|
||||
@@ -950,14 +967,22 @@ where
|
||||
tracing::info!(target: "relay", %port, "Deleted allocation");
|
||||
}
|
||||
|
||||
fn delete_channel_binding(&mut self, client: SocketAddr, chan: u16) {
|
||||
fn delete_channel_binding(&mut self, client: ClientSocket, chan: u16) {
|
||||
let Some(channel) = self.channels_by_client_and_number.get(&(client, chan)) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let addr = channel.peer_address;
|
||||
|
||||
self.channel_numbers_by_peer.remove(&addr);
|
||||
let _peer_channel = self
|
||||
.channel_numbers_by_client_and_peer
|
||||
.remove(&(client, addr));
|
||||
debug_assert_eq!(
|
||||
_peer_channel,
|
||||
Some(chan),
|
||||
"Channel state should be consistent"
|
||||
);
|
||||
|
||||
self.channels_by_client_and_number.remove(&(client, chan));
|
||||
}
|
||||
}
|
||||
@@ -999,7 +1024,7 @@ struct Channel {
|
||||
expiry: SystemTime,
|
||||
|
||||
/// The address of the peer that the channel is bound to.
|
||||
peer_address: SocketAddr,
|
||||
peer_address: PeerSocket,
|
||||
|
||||
/// The allocation this channel belongs to.
|
||||
allocation: AllocationId,
|
||||
@@ -1031,8 +1056,8 @@ impl Allocation {
|
||||
///
|
||||
/// This is called in the context of a channel binding with the requested peer address.
|
||||
/// We can only relay to the address if the allocation supports the same version of the IP protocol.
|
||||
fn can_relay_to(&self, addr: SocketAddr) -> bool {
|
||||
match addr {
|
||||
fn can_relay_to(&self, addr: PeerSocket) -> bool {
|
||||
match addr.0 {
|
||||
SocketAddr::V4(_) => self.first_relay_addr.is_ipv4(), // If we have an IPv4 address, it is in `first_relay_addr`, no need to check `second_relay_addr`.
|
||||
SocketAddr::V6(_) => {
|
||||
self.first_relay_addr.is_ipv6()
|
||||
@@ -1051,8 +1076,8 @@ impl Allocation {
|
||||
#[derive(PartialEq)]
|
||||
enum TimedAction {
|
||||
ExpireAllocation(AllocationId),
|
||||
UnbindChannel((SocketAddr, u16)),
|
||||
DeleteChannel((SocketAddr, u16)),
|
||||
UnbindChannel((ClientSocket, u16)),
|
||||
DeleteChannel((ClientSocket, u16)),
|
||||
}
|
||||
|
||||
fn error_response(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use bytecodec::{DecodeExt, EncodeExt};
|
||||
use firezone_relay::{
|
||||
AddressFamily, Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData,
|
||||
ClientMessage, Command, IpStack, Refresh, Server,
|
||||
ClientMessage, ClientSocket, Command, IpStack, PeerSocket, Refresh, Server,
|
||||
};
|
||||
use rand::rngs::mock::StepRng;
|
||||
use secrecy::SecretString;
|
||||
@@ -686,8 +686,8 @@ fn parse_message(message: &[u8]) -> Message<Attribute> {
|
||||
}
|
||||
|
||||
enum Input<'a> {
|
||||
Client(SocketAddr, ClientMessage<'a>, SystemTime),
|
||||
Peer(SocketAddr, Vec<u8>, u16),
|
||||
Client(ClientSocket, ClientMessage<'a>, SystemTime),
|
||||
Peer(PeerSocket, Vec<u8>, u16),
|
||||
Time(SystemTime),
|
||||
}
|
||||
|
||||
@@ -696,11 +696,11 @@ fn from_client<'a>(
|
||||
message: impl Into<ClientMessage<'a>>,
|
||||
now: SystemTime,
|
||||
) -> Input<'a> {
|
||||
Input::Client(from.into(), message.into(), now)
|
||||
Input::Client(ClientSocket::new(from.into()), message.into(), now)
|
||||
}
|
||||
|
||||
fn from_peer<'a>(from: impl Into<SocketAddr>, data: &[u8], port: u16) -> Input<'a> {
|
||||
Input::Peer(from.into(), data.to_vec(), port)
|
||||
Input::Peer(PeerSocket::new(from.into()), data.to_vec(), port)
|
||||
}
|
||||
|
||||
fn forward_time_to<'a>(when: SystemTime) -> Input<'a> {
|
||||
@@ -709,22 +709,22 @@ fn forward_time_to<'a>(when: SystemTime) -> Input<'a> {
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Output<'a> {
|
||||
SendMessage((SocketAddr, Message<Attribute>)),
|
||||
SendChannelData((SocketAddr, ChannelData<'a>)),
|
||||
Forward((SocketAddr, Vec<u8>, u16)),
|
||||
SendMessage((ClientSocket, Message<Attribute>)),
|
||||
SendChannelData((ClientSocket, ChannelData<'a>)),
|
||||
Forward((PeerSocket, Vec<u8>, u16)),
|
||||
Wake(SystemTime),
|
||||
CreateAllocation(u16, AddressFamily),
|
||||
FreeAllocation(u16, AddressFamily),
|
||||
}
|
||||
|
||||
fn send_message<'a>(source: impl Into<SocketAddr>, message: Message<Attribute>) -> Output<'a> {
|
||||
Output::SendMessage((source.into(), message))
|
||||
Output::SendMessage((ClientSocket::new(source.into()), message))
|
||||
}
|
||||
|
||||
fn send_channel_data(source: impl Into<SocketAddr>, message: ChannelData) -> Output {
|
||||
Output::SendChannelData((source.into(), message))
|
||||
Output::SendChannelData((ClientSocket::new(source.into()), message))
|
||||
}
|
||||
|
||||
fn forward(source: impl Into<SocketAddr>, data: &[u8], port: u16) -> Output {
|
||||
Output::Forward((source.into(), data.to_vec(), port))
|
||||
Output::Forward((PeerSocket::new(source.into()), data.to_vec(), port))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user