mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(relay): introduce type-safe Server APIs (#1630)
We introduce dedicated types for each message that the `Server` can handle. This allows us to make the functions public because the type-system now guarantees that those are either parsed from bytes or constructed with the correct data. The latter will be useful to write tests against a richer API.
This commit is contained in:
7
rust/Cargo.lock
generated
7
rust/Cargo.lock
generated
@@ -603,6 +603,12 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "difference"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "524cbf6897b527295dff137cec09ecf3a05f4fddffd7dfcd1585403449e74198"
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.9.0"
|
||||
@@ -1525,6 +1531,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"bytecodec",
|
||||
"bytes",
|
||||
"difference",
|
||||
"env_logger",
|
||||
"futures",
|
||||
"hex",
|
||||
|
||||
@@ -20,3 +20,4 @@ bytes = "1.4.0"
|
||||
[dev-dependencies]
|
||||
webrtc = "0.7.2"
|
||||
redis = { version = "0.23.0", default-features = false, features = ["tokio-comp"] }
|
||||
difference = "2.0.0"
|
||||
|
||||
@@ -190,7 +190,7 @@ impl Eventloop {
|
||||
// Priority 6: Accept new allocations / answer STUN requests etc
|
||||
if let Poll::Ready((buffer, sender)) = self.ip4_socket.poll_recv(cx)? {
|
||||
self.server
|
||||
.handle_client_input(buffer.filled(), sender, Instant::now())?;
|
||||
.handle_client_input(buffer.filled(), sender, Instant::now());
|
||||
continue; // Handle potentially new commands.
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
mod channel_data;
|
||||
mod client_message;
|
||||
|
||||
use crate::rfc8656::PeerAddressFamilyMismatch;
|
||||
use crate::server::channel_data::ChannelData;
|
||||
use crate::server::client_message::{
|
||||
Allocate, Binding, ChannelBind, ClientMessage, CreatePermission, Refresh,
|
||||
};
|
||||
use crate::stun_codec_ext::{MessageClassExt, MethodExt};
|
||||
use crate::TimeEvents;
|
||||
use anyhow::Result;
|
||||
use bytecodec::{DecodeExt, EncodeExt};
|
||||
use bytecodec::EncodeExt;
|
||||
use core::fmt;
|
||||
use rand::rngs::mock::StepRng;
|
||||
use rand::rngs::ThreadRng;
|
||||
@@ -23,7 +28,7 @@ use stun_codec::rfc5766::attributes::{
|
||||
};
|
||||
use stun_codec::rfc5766::errors::{AllocationMismatch, InsufficientCapacity};
|
||||
use stun_codec::rfc5766::methods::{ALLOCATE, CHANNEL_BIND, CREATE_PERMISSION, REFRESH};
|
||||
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder, Method, TransactionId};
|
||||
use stun_codec::{Message, MessageClass, MessageEncoder, Method, TransactionId};
|
||||
|
||||
/// A sans-IO STUN & TURN server.
|
||||
///
|
||||
@@ -33,7 +38,7 @@ use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder, Method,
|
||||
///
|
||||
/// Additionally, we assume to have complete ownership over the port range `LOWEST_PORT` - `HIGHEST_PORT`.
|
||||
pub struct Server<R = ThreadRng> {
|
||||
decoder: MessageDecoder<Attribute>,
|
||||
decoder: client_message::Decoder,
|
||||
encoder: MessageEncoder<Attribute>,
|
||||
|
||||
public_ip4_address: Ipv4Addr,
|
||||
@@ -108,14 +113,6 @@ const HIGHEST_PORT: u16 = 65535;
|
||||
/// The maximum number of ports available for allocation.
|
||||
const MAX_AVAILABLE_PORTS: u16 = HIGHEST_PORT - LOWEST_PORT;
|
||||
|
||||
/// The maximum lifetime of an allocation.
|
||||
const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600);
|
||||
|
||||
/// The default lifetime of an allocation.
|
||||
///
|
||||
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-allocations-2>.
|
||||
const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600);
|
||||
|
||||
/// The duration of a channel binding.
|
||||
///
|
||||
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-channels-2>.
|
||||
@@ -149,71 +146,58 @@ where
|
||||
/// Process the bytes received from a client.
|
||||
///
|
||||
/// After calling this method, you should call [`Server::next_command`] until it returns `None`.
|
||||
pub fn handle_client_input(
|
||||
&mut self,
|
||||
bytes: &[u8],
|
||||
sender: SocketAddr,
|
||||
now: Instant,
|
||||
) -> Result<()> {
|
||||
pub fn handle_client_input(&mut self, bytes: &[u8], sender: SocketAddr, now: Instant) {
|
||||
if tracing::enabled!(target: "wire", tracing::Level::TRACE) {
|
||||
let hex_bytes = hex::encode(bytes);
|
||||
tracing::trace!(target: "wire", r#"Input::client("{sender}","{hex_bytes}")"#);
|
||||
}
|
||||
|
||||
// De-multiplex as per <https://www.rfc-editor.org/rfc/rfc8656#name-channels-2>.
|
||||
match bytes.first() {
|
||||
Some(0..=3) => {
|
||||
let Ok(message) = self.decoder.decode_from_bytes(bytes)? else {
|
||||
tracing::warn!(target: "relay", "received broken STUN message from {sender}");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
tracing::trace!(target: "relay", "Received {} {} from {sender}", message.method().as_str(), message.class().as_str());
|
||||
|
||||
self.dispatch_stun_message(message, sender, now, |server, message, sender, now| {
|
||||
use MessageClass::*;
|
||||
match (message.method(), message.class()) {
|
||||
(BINDING, Request) => {
|
||||
server.handle_binding_request(message, sender);
|
||||
Ok(())
|
||||
}
|
||||
(ALLOCATE, Request) => server.handle_allocate_request(message, sender, now),
|
||||
(REFRESH, Request) => server.handle_refresh_request(message, sender, now),
|
||||
(CHANNEL_BIND, Request) => {
|
||||
server.handle_channel_bind_request(message, sender, now)
|
||||
}
|
||||
(CREATE_PERMISSION, Request) => {
|
||||
server.handle_create_permission_request(message, sender, now)
|
||||
}
|
||||
(_, Indication) => {
|
||||
tracing::trace!(target: "relay", "Indications are not yet implemented");
|
||||
|
||||
Err(ErrorCode::from(BadRequest))
|
||||
}
|
||||
_ => Err(ErrorCode::from(BadRequest)),
|
||||
}
|
||||
});
|
||||
let result = match self.decoder.decode(bytes) {
|
||||
Ok(Ok(ClientMessage::Allocate(request))) => {
|
||||
self.handle_allocate_request(request, sender, now)
|
||||
}
|
||||
Ok(Ok(ClientMessage::Refresh(request))) => {
|
||||
self.handle_refresh_request(request, sender, now)
|
||||
}
|
||||
Ok(Ok(ClientMessage::ChannelBind(request))) => {
|
||||
self.handle_channel_bind_request(request, sender, now)
|
||||
}
|
||||
Ok(Ok(ClientMessage::CreatePermission(request))) => {
|
||||
self.handle_create_permission_request(request, sender, now)
|
||||
}
|
||||
Ok(Ok(ClientMessage::Binding(request))) => {
|
||||
self.handle_binding_request(request, sender);
|
||||
return;
|
||||
}
|
||||
Ok(Ok(ClientMessage::ChannelData(msg))) => {
|
||||
self.handle_channel_data_message(msg, sender, now);
|
||||
return;
|
||||
}
|
||||
Some(64..=79) => {
|
||||
let (channel, data) = match channel_data::parse(bytes) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
target: "relay",
|
||||
"failed to parse channel data message: {e:#}"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
self.handle_channel_data_message(channel, data, sender, now);
|
||||
}
|
||||
_ => {
|
||||
tracing::trace!(target: "relay", "Received unknown message from {sender}");
|
||||
}
|
||||
// Could parse the bytes but message was semantically invalid (like missing attribute).
|
||||
Ok(Err(error_code)) => Err(error_code),
|
||||
|
||||
// Parsing the bytes failed.
|
||||
Err(client_message::Error::BadChannelData(_)) => return,
|
||||
Err(client_message::Error::DecodeStun(_)) => return,
|
||||
Err(client_message::Error::UnknownMessageType(_)) => return,
|
||||
Err(client_message::Error::Eof) => return,
|
||||
};
|
||||
|
||||
let Err(mut error_response) = result else {
|
||||
return;
|
||||
};
|
||||
|
||||
// In case of a 401 response, attach a realm and nonce.
|
||||
if error_response
|
||||
.get_attribute::<ErrorCode>()
|
||||
.map_or(false, |error| error == &ErrorCode::from(Unauthorized))
|
||||
{
|
||||
error_response.add_attribute(Nonce::new("foobar".to_owned()).unwrap().into());
|
||||
error_response.add_attribute(Realm::new("firezone".to_owned()).unwrap().into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
self.send_message(error_response, sender);
|
||||
}
|
||||
|
||||
/// Process the bytes received from an allocation.
|
||||
@@ -264,7 +248,7 @@ where
|
||||
);
|
||||
|
||||
let recipient = *client;
|
||||
let data = channel_data::make(*channel_number, bytes);
|
||||
let data = ChannelData::new(*channel_number, bytes).to_bytes();
|
||||
|
||||
if tracing::enabled!(target: "wire", tracing::Level::TRACE) {
|
||||
let hex_bytes = hex::encode(&data);
|
||||
@@ -321,27 +305,7 @@ where
|
||||
self.pending_commands.pop_front()
|
||||
}
|
||||
|
||||
fn dispatch_stun_message(
|
||||
&mut self,
|
||||
message: Message<Attribute>,
|
||||
sender: SocketAddr,
|
||||
now: Instant,
|
||||
handler: impl Fn(&mut Self, Message<Attribute>, SocketAddr, Instant) -> Result<(), ErrorCode>,
|
||||
) {
|
||||
let transaction_id = message.transaction_id();
|
||||
let method = message.method();
|
||||
|
||||
if let Err(e) = handler(self, message, sender, now) {
|
||||
if e.code() == Unauthorized::CODEPOINT {
|
||||
self.send_message(unauthorized(transaction_id, method), sender);
|
||||
return;
|
||||
}
|
||||
|
||||
self.send_message(error_response(transaction_id, method, e), sender);
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_binding_request(&mut self, message: Message<Attribute>, sender: SocketAddr) {
|
||||
pub fn handle_binding_request(&mut self, message: Binding, sender: SocketAddr) {
|
||||
let mut message = Message::new(
|
||||
MessageClass::SuccessResponse,
|
||||
BINDING,
|
||||
@@ -355,35 +319,27 @@ where
|
||||
/// Handle a TURN allocate request.
|
||||
///
|
||||
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-receiving-an-allocate-reque> for details.
|
||||
fn handle_allocate_request(
|
||||
pub fn handle_allocate_request(
|
||||
&mut self,
|
||||
message: Message<Attribute>,
|
||||
request: Allocate,
|
||||
sender: SocketAddr,
|
||||
now: Instant,
|
||||
) -> Result<(), ErrorCode> {
|
||||
let _ = message
|
||||
.get_attribute::<MessageIntegrity>()
|
||||
.ok_or(Unauthorized)?;
|
||||
) -> Result<(), Message<Attribute>> {
|
||||
// TODO: Check validity of message integrity here?
|
||||
|
||||
if self.allocations.contains_key(&sender) {
|
||||
return Err(AllocationMismatch.into());
|
||||
return Err(error_response(AllocationMismatch, &request));
|
||||
}
|
||||
|
||||
if self.allocations_by_port.len() == MAX_AVAILABLE_PORTS as usize {
|
||||
return Err(InsufficientCapacity.into());
|
||||
return Err(error_response(InsufficientCapacity, &request));
|
||||
}
|
||||
|
||||
let requested_transport = message
|
||||
.get_attribute::<RequestedTransport>()
|
||||
.ok_or(BadRequest)?;
|
||||
|
||||
if requested_transport.protocol() != UDP_TRANSPORT {
|
||||
return Err(BadRequest.into());
|
||||
if request.requested_transport().protocol() != UDP_TRANSPORT {
|
||||
return Err(error_response(BadRequest, &request));
|
||||
}
|
||||
|
||||
let requested_lifetime = message.get_attribute::<Lifetime>();
|
||||
|
||||
let effective_lifetime = compute_effective_lifetime(requested_lifetime);
|
||||
let effective_lifetime = request.effective_lifetime();
|
||||
|
||||
// TODO: Do we need to handle DONT-FRAGMENT?
|
||||
// TODO: Do we need to handle EVEN/ODD-PORT?
|
||||
@@ -393,7 +349,7 @@ where
|
||||
let mut message = Message::new(
|
||||
MessageClass::SuccessResponse,
|
||||
ALLOCATE,
|
||||
message.transaction_id(),
|
||||
request.transaction_id(),
|
||||
);
|
||||
|
||||
let ip4_relay_address = self.public_relay_address_for_port(allocation.port);
|
||||
@@ -430,24 +386,21 @@ where
|
||||
/// Handle a TURN refresh request.
|
||||
///
|
||||
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-receiving-a-refresh-request> for details.
|
||||
fn handle_refresh_request(
|
||||
pub fn handle_refresh_request(
|
||||
&mut self,
|
||||
message: Message<Attribute>,
|
||||
request: Refresh,
|
||||
sender: SocketAddr,
|
||||
now: Instant,
|
||||
) -> Result<(), ErrorCode> {
|
||||
let _ = message
|
||||
.get_attribute::<MessageIntegrity>()
|
||||
.ok_or(Unauthorized)?;
|
||||
) -> Result<(), Message<Attribute>> {
|
||||
// TODO: Check validity of message integrity here?
|
||||
|
||||
// TODO: Verify that this is the correct error code.
|
||||
let allocation = self
|
||||
.allocations
|
||||
.get_mut(&sender)
|
||||
.ok_or(ErrorCode::from(AllocationMismatch))?;
|
||||
.ok_or(error_response(AllocationMismatch, &request))?;
|
||||
|
||||
let requested_lifetime = message.get_attribute::<Lifetime>();
|
||||
let effective_lifetime = compute_effective_lifetime(requested_lifetime);
|
||||
let effective_lifetime = request.effective_lifetime();
|
||||
|
||||
if effective_lifetime.lifetime().is_zero() {
|
||||
let port = allocation.port;
|
||||
@@ -457,7 +410,7 @@ where
|
||||
self.allocations.remove(&sender);
|
||||
self.allocations_by_port.remove(&port);
|
||||
self.send_message(
|
||||
refresh_success_response(effective_lifetime, message.transaction_id()),
|
||||
refresh_success_response(effective_lifetime, request.transaction_id()),
|
||||
sender,
|
||||
);
|
||||
|
||||
@@ -486,7 +439,7 @@ where
|
||||
deadline: wake_deadline,
|
||||
});
|
||||
self.send_message(
|
||||
refresh_success_response(effective_lifetime, message.transaction_id()),
|
||||
refresh_success_response(effective_lifetime, request.transaction_id()),
|
||||
sender,
|
||||
);
|
||||
|
||||
@@ -496,49 +449,40 @@ where
|
||||
/// Handle a TURN channel bind request.
|
||||
///
|
||||
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-receiving-a-channelbind-req> for details.
|
||||
fn handle_channel_bind_request(
|
||||
pub fn handle_channel_bind_request(
|
||||
&mut self,
|
||||
message: Message<Attribute>,
|
||||
request: ChannelBind,
|
||||
sender: SocketAddr,
|
||||
now: Instant,
|
||||
) -> Result<(), ErrorCode> {
|
||||
let _ = message
|
||||
.get_attribute::<MessageIntegrity>()
|
||||
.ok_or(Unauthorized)?;
|
||||
) -> Result<(), Message<Attribute>> {
|
||||
// TODO: Check validity of message integrity here?
|
||||
|
||||
let allocation = self
|
||||
.allocations
|
||||
.get_mut(&sender)
|
||||
.ok_or(ErrorCode::from(AllocationMismatch))?;
|
||||
.ok_or(error_response(AllocationMismatch, &request))?;
|
||||
|
||||
let requested_channel = message
|
||||
.get_attribute::<ChannelNumber>()
|
||||
.ok_or(ErrorCode::from(BadRequest))?
|
||||
.value();
|
||||
|
||||
let peer_address = message
|
||||
.get_attribute::<XorPeerAddress>()
|
||||
.ok_or(ErrorCode::from(BadRequest))?
|
||||
.address();
|
||||
let requested_channel = request.channel_number().value();
|
||||
let peer_address = request.xor_peer_address().address();
|
||||
|
||||
// Note: `channel_number` is enforced to be in the correct range.
|
||||
|
||||
// Check that our allocation can handle the requested peer addr.
|
||||
if !allocation.can_relay_to(peer_address) {
|
||||
return Err(ErrorCode::from(PeerAddressFamilyMismatch));
|
||||
return Err(error_response(PeerAddressFamilyMismatch, &request));
|
||||
}
|
||||
|
||||
// Ensure the same address isn't already bound to a different channel.
|
||||
if let Some(number) = self.channel_numbers_by_peer.get(&peer_address) {
|
||||
if number != &requested_channel {
|
||||
return Err(ErrorCode::from(BadRequest));
|
||||
return Err(error_response(BadRequest, &request));
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the channel is not already bound to a different address.
|
||||
if let Some(channel) = self.channels_by_number.get_mut(&requested_channel) {
|
||||
if channel.peer_address != peer_address {
|
||||
return Err(ErrorCode::from(BadRequest));
|
||||
return Err(error_response(BadRequest, &request));
|
||||
}
|
||||
|
||||
// Binding requests for existing channels act as a refresh for the binding.
|
||||
@@ -552,7 +496,7 @@ where
|
||||
TimedAction::UnbindChannel(requested_channel),
|
||||
);
|
||||
self.send_message(
|
||||
channel_bind_success_response(message.transaction_id()),
|
||||
channel_bind_success_response(request.transaction_id()),
|
||||
sender,
|
||||
);
|
||||
|
||||
@@ -567,7 +511,7 @@ where
|
||||
let allocation_id = allocation.id;
|
||||
self.create_channel_binding(requested_channel, peer_address, allocation_id, now);
|
||||
self.send_message(
|
||||
channel_bind_success_response(message.transaction_id()),
|
||||
channel_bind_success_response(request.transaction_id()),
|
||||
sender,
|
||||
);
|
||||
|
||||
@@ -582,15 +526,14 @@ where
|
||||
///
|
||||
/// This TURN server implementation does not support relaying data other than through channels.
|
||||
/// Thus, creating a permission is a no-op that always succeeds.
|
||||
fn handle_create_permission_request(
|
||||
pub fn handle_create_permission_request(
|
||||
&mut self,
|
||||
message: Message<Attribute>,
|
||||
message: CreatePermission,
|
||||
sender: SocketAddr,
|
||||
_: Instant,
|
||||
) -> Result<(), ErrorCode> {
|
||||
let _ = message
|
||||
.get_attribute::<MessageIntegrity>()
|
||||
.ok_or(Unauthorized)?;
|
||||
) -> Result<(), Message<Attribute>> {
|
||||
// TODO: Check validity of message integrity here?
|
||||
|
||||
self.send_message(
|
||||
create_permission_success_response(message.transaction_id()),
|
||||
sender,
|
||||
@@ -599,13 +542,15 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_channel_data_message(
|
||||
pub fn handle_channel_data_message(
|
||||
&mut self,
|
||||
channel_number: u16,
|
||||
data: &[u8],
|
||||
message: ChannelData,
|
||||
sender: SocketAddr,
|
||||
_: Instant,
|
||||
) {
|
||||
let channel_number = message.channel();
|
||||
let data = message.data();
|
||||
|
||||
let Some(channel) = self.channels_by_number.get(&channel_number) else {
|
||||
tracing::debug!(target: "relay", "Channel {channel_number} does not exist, refusing to forward data");
|
||||
return;
|
||||
@@ -834,37 +779,45 @@ enum TimedAction {
|
||||
DeleteChannel(u16),
|
||||
}
|
||||
|
||||
/// Computes the effective lifetime of an allocation.
|
||||
fn compute_effective_lifetime(requested_lifetime: Option<&Lifetime>) -> Lifetime {
|
||||
let Some(requested) = requested_lifetime else {
|
||||
return Lifetime::new(DEFAULT_ALLOCATION_LIFETIME).unwrap();
|
||||
};
|
||||
|
||||
let effective_lifetime = requested.lifetime().min(MAX_ALLOCATION_LIFETIME);
|
||||
|
||||
Lifetime::new(effective_lifetime).unwrap()
|
||||
}
|
||||
|
||||
fn error_response(
|
||||
transaction_id: TransactionId,
|
||||
method: Method,
|
||||
error_code: ErrorCode,
|
||||
error_code: impl Into<ErrorCode>,
|
||||
request: &impl StunRequest,
|
||||
) -> Message<Attribute> {
|
||||
let mut message = Message::new(MessageClass::ErrorResponse, method, transaction_id);
|
||||
message.add_attribute(error_code.into());
|
||||
let mut message = Message::new(
|
||||
MessageClass::ErrorResponse,
|
||||
request.method(),
|
||||
request.transaction_id(),
|
||||
);
|
||||
message.add_attribute(Attribute::from(error_code.into()));
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
fn unauthorized(transaction_id: TransactionId, method: Method) -> Message<Attribute> {
|
||||
let mut message = Message::new(MessageClass::ErrorResponse, method, transaction_id);
|
||||
message.add_attribute(ErrorCode::from(Unauthorized).into());
|
||||
message.add_attribute(Nonce::new("foobar".to_owned()).unwrap().into());
|
||||
message.add_attribute(Realm::new("firezone".to_owned()).unwrap().into());
|
||||
|
||||
message
|
||||
/// Private helper trait to make [`error_response`] more ergonomic to use.
|
||||
trait StunRequest {
|
||||
fn transaction_id(&self) -> TransactionId;
|
||||
fn method(&self) -> Method;
|
||||
}
|
||||
|
||||
macro_rules! impl_stun_request_for {
|
||||
($t:ty, $m:expr) => {
|
||||
impl StunRequest for $t {
|
||||
fn transaction_id(&self) -> TransactionId {
|
||||
self.transaction_id()
|
||||
}
|
||||
|
||||
fn method(&self) -> Method {
|
||||
$m
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_stun_request_for!(Allocate, ALLOCATE);
|
||||
impl_stun_request_for!(ChannelBind, CHANNEL_BIND);
|
||||
impl_stun_request_for!(CreatePermission, CREATE_PERMISSION);
|
||||
impl_stun_request_for!(Refresh, REFRESH);
|
||||
|
||||
// Define an enum of all attributes that we care about for our server.
|
||||
stun_codec::define_attribute_enums!(
|
||||
Attribute,
|
||||
@@ -884,17 +837,3 @@ stun_codec::define_attribute_enums!(
|
||||
Username
|
||||
]
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn requested_lifetime_is_capped_at_max_lifetime() {
|
||||
let requested_lifetime = Lifetime::new(Duration::from_secs(10_000_000)).unwrap();
|
||||
|
||||
let effective_lifetime = compute_effective_lifetime(Some(&requested_lifetime));
|
||||
|
||||
assert_eq!(effective_lifetime.lifetime(), MAX_ALLOCATION_LIFETIME)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,27 +1,61 @@
|
||||
use anyhow::{bail, Result};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use std::io;
|
||||
|
||||
pub(crate) fn make(channel: u16, data: &[u8]) -> Vec<u8> {
|
||||
let mut message = BytesMut::with_capacity(2 + 2 + data.len());
|
||||
|
||||
message.put_u16(channel);
|
||||
message.put_u16(data.len() as u16);
|
||||
message.put_slice(data);
|
||||
|
||||
message.freeze().to_vec()
|
||||
pub struct ChannelData<'a> {
|
||||
channel: u16,
|
||||
data: &'a [u8],
|
||||
}
|
||||
|
||||
pub(crate) fn parse(data: &[u8]) -> Result<(u16, &[u8])> {
|
||||
if data.len() < 4 {
|
||||
bail!("must have at least 4 bytes for channel data message")
|
||||
impl<'a> ChannelData<'a> {
|
||||
pub fn parse(data: &'a [u8]) -> Result<Self, io::Error> {
|
||||
if data.len() < 4 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"channel data messages are at least 4 bytes long",
|
||||
));
|
||||
}
|
||||
|
||||
let channel_number = u16::from_be_bytes([data[0], data[1]]);
|
||||
let length = u16::from_be_bytes([data[2], data[3]]);
|
||||
|
||||
let actual_payload_length = data.len() - 4;
|
||||
|
||||
if actual_payload_length != length as usize {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"channel data message specified {length} bytes but got {actual_payload_length}"
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(ChannelData {
|
||||
channel: channel_number,
|
||||
data: &data[4..],
|
||||
})
|
||||
}
|
||||
|
||||
let channel_number = u16::from_be_bytes([data[0], data[1]]);
|
||||
let length = u16::from_be_bytes([data[2], data[3]]);
|
||||
pub fn new(channel: u16, data: &'a [u8]) -> Self {
|
||||
ChannelData { channel, data }
|
||||
}
|
||||
|
||||
anyhow::ensure!((data.len() - 4) == length as usize);
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
let mut message = BytesMut::with_capacity(2 + 2 + self.data.len());
|
||||
|
||||
Ok((channel_number, &data[4..]))
|
||||
message.put_u16(self.channel);
|
||||
message.put_u16(self.data.len() as u16);
|
||||
message.put_slice(self.data);
|
||||
|
||||
message.freeze().into()
|
||||
}
|
||||
|
||||
pub fn channel(&self) -> u16 {
|
||||
self.channel
|
||||
}
|
||||
|
||||
pub fn data(&self) -> &[u8] {
|
||||
self.data
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: tests
|
||||
|
||||
400
rust/relay/src/server/client_message.rs
Normal file
400
rust/relay/src/server/client_message.rs
Normal file
@@ -0,0 +1,400 @@
|
||||
use crate::server::channel_data::ChannelData;
|
||||
use crate::Attribute;
|
||||
use bytecodec::DecodeExt;
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
use stun_codec::rfc5389::attributes::{ErrorCode, MessageIntegrity, Username};
|
||||
use stun_codec::rfc5389::errors::{BadRequest, Unauthorized};
|
||||
use stun_codec::rfc5389::methods::BINDING;
|
||||
use stun_codec::rfc5766::attributes::{
|
||||
ChannelNumber, Lifetime, RequestedTransport, XorPeerAddress,
|
||||
};
|
||||
use stun_codec::rfc5766::methods::{ALLOCATE, CHANNEL_BIND, CREATE_PERMISSION, REFRESH};
|
||||
use stun_codec::{BrokenMessage, Message, MessageClass, TransactionId};
|
||||
|
||||
/// The maximum lifetime of an allocation.
|
||||
const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600);
|
||||
|
||||
/// The default lifetime of an allocation.
|
||||
///
|
||||
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-allocations-2>.
|
||||
const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600);
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Decoder {
|
||||
stun_message_decoder: stun_codec::MessageDecoder<Attribute>,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
pub fn decode<'a>(
|
||||
&mut self,
|
||||
input: &'a [u8],
|
||||
) -> Result<Result<ClientMessage<'a>, Message<Attribute>>, Error> {
|
||||
// De-multiplex as per <https://www.rfc-editor.org/rfc/rfc8656#name-channels-2>.
|
||||
match input.first() {
|
||||
Some(0..=3) => {
|
||||
let message = self.stun_message_decoder.decode_from_bytes(input)??;
|
||||
|
||||
use MessageClass::*;
|
||||
match (message.method(), message.class()) {
|
||||
(BINDING, Request) => Ok(Ok(ClientMessage::Binding(Binding::parse(&message)))),
|
||||
(ALLOCATE, Request) => {
|
||||
Ok(Allocate::parse(&message).map(ClientMessage::Allocate))
|
||||
}
|
||||
(REFRESH, Request) => Ok(Refresh::parse(&message).map(ClientMessage::Refresh)),
|
||||
(CHANNEL_BIND, Request) => {
|
||||
Ok(ChannelBind::parse(&message).map(ClientMessage::ChannelBind))
|
||||
}
|
||||
(CREATE_PERMISSION, Request) => {
|
||||
Ok(CreatePermission::parse(&message).map(ClientMessage::CreatePermission))
|
||||
}
|
||||
(_, Request) => Ok(Err(bad_request(&message))),
|
||||
(method, class) => {
|
||||
Err(Error::DecodeStun(bytecodec::Error::from(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
format!(
|
||||
"handling method {} and {class:?} is not implemented",
|
||||
method.as_u16()
|
||||
),
|
||||
))))
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(64..=79) => Ok(Ok(ClientMessage::ChannelData(ChannelData::parse(input)?))),
|
||||
Some(other) => Err(Error::UnknownMessageType(*other)),
|
||||
None => Err(Error::Eof),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ClientMessage<'a> {
|
||||
ChannelData(ChannelData<'a>),
|
||||
Binding(Binding),
|
||||
Allocate(Allocate),
|
||||
Refresh(Refresh),
|
||||
ChannelBind(ChannelBind),
|
||||
CreatePermission(CreatePermission),
|
||||
}
|
||||
|
||||
pub struct Binding {
|
||||
transaction_id: TransactionId,
|
||||
}
|
||||
|
||||
impl Binding {
|
||||
pub fn new(transaction_id: TransactionId) -> Self {
|
||||
Self { transaction_id }
|
||||
}
|
||||
|
||||
pub fn parse(message: &Message<Attribute>) -> Self {
|
||||
let transaction_id = message.transaction_id();
|
||||
|
||||
Binding { transaction_id }
|
||||
}
|
||||
|
||||
pub fn transaction_id(&self) -> TransactionId {
|
||||
self.transaction_id
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Allocate {
|
||||
transaction_id: TransactionId,
|
||||
message_integrity: MessageIntegrity,
|
||||
requested_transport: RequestedTransport,
|
||||
lifetime: Option<Lifetime>,
|
||||
}
|
||||
|
||||
impl Allocate {
|
||||
pub fn new(
|
||||
transaction_id: TransactionId,
|
||||
message_integrity: MessageIntegrity,
|
||||
requested_transport: RequestedTransport,
|
||||
lifetime: Option<Lifetime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
transaction_id,
|
||||
message_integrity,
|
||||
requested_transport,
|
||||
lifetime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse(message: &Message<Attribute>) -> Result<Self, Message<Attribute>> {
|
||||
let transaction_id = message.transaction_id();
|
||||
let message_integrity = message
|
||||
.get_attribute::<MessageIntegrity>()
|
||||
.ok_or(unauthorized(message))?
|
||||
.clone();
|
||||
let requested_transport = message
|
||||
.get_attribute::<RequestedTransport>()
|
||||
.ok_or(bad_request(message))?
|
||||
.clone();
|
||||
let lifetime = message.get_attribute::<Lifetime>().cloned();
|
||||
|
||||
Ok(Allocate {
|
||||
transaction_id,
|
||||
message_integrity,
|
||||
requested_transport,
|
||||
lifetime,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn transaction_id(&self) -> TransactionId {
|
||||
self.transaction_id
|
||||
}
|
||||
|
||||
pub fn message_integrity(&self) -> &MessageIntegrity {
|
||||
&self.message_integrity
|
||||
}
|
||||
|
||||
pub fn requested_transport(&self) -> &RequestedTransport {
|
||||
&self.requested_transport
|
||||
}
|
||||
|
||||
pub fn effective_lifetime(&self) -> Lifetime {
|
||||
compute_effective_lifetime(self.lifetime.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Refresh {
|
||||
transaction_id: TransactionId,
|
||||
message_integrity: MessageIntegrity,
|
||||
lifetime: Option<Lifetime>,
|
||||
}
|
||||
|
||||
impl Refresh {
|
||||
pub fn new(
|
||||
transaction_id: TransactionId,
|
||||
message_integrity: MessageIntegrity,
|
||||
lifetime: Option<Lifetime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
transaction_id,
|
||||
message_integrity,
|
||||
lifetime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse(message: &Message<Attribute>) -> Result<Self, Message<Attribute>> {
|
||||
let transaction_id = message.transaction_id();
|
||||
let message_integrity = message
|
||||
.get_attribute::<MessageIntegrity>()
|
||||
.ok_or(unauthorized(message))?
|
||||
.clone();
|
||||
let lifetime = message.get_attribute::<Lifetime>().cloned();
|
||||
|
||||
Ok(Refresh {
|
||||
transaction_id,
|
||||
message_integrity,
|
||||
lifetime,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn transaction_id(&self) -> TransactionId {
|
||||
self.transaction_id
|
||||
}
|
||||
|
||||
pub fn message_integrity(&self) -> &MessageIntegrity {
|
||||
&self.message_integrity
|
||||
}
|
||||
|
||||
pub fn effective_lifetime(&self) -> Lifetime {
|
||||
compute_effective_lifetime(self.lifetime.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ChannelBind {
|
||||
transaction_id: TransactionId,
|
||||
channel_number: ChannelNumber,
|
||||
message_integrity: MessageIntegrity,
|
||||
xor_peer_address: XorPeerAddress,
|
||||
username: Username,
|
||||
}
|
||||
|
||||
impl ChannelBind {
|
||||
pub fn new(
|
||||
transaction_id: TransactionId,
|
||||
channel_number: ChannelNumber,
|
||||
message_integrity: MessageIntegrity,
|
||||
username: Username,
|
||||
xor_peer_address: XorPeerAddress,
|
||||
) -> Self {
|
||||
Self {
|
||||
transaction_id,
|
||||
channel_number,
|
||||
message_integrity,
|
||||
xor_peer_address,
|
||||
username,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse(message: &Message<Attribute>) -> Result<Self, Message<Attribute>> {
|
||||
let transaction_id = message.transaction_id();
|
||||
let channel_number = message
|
||||
.get_attribute::<ChannelNumber>()
|
||||
.copied()
|
||||
.ok_or(bad_request(message))?;
|
||||
let message_integrity = message
|
||||
.get_attribute::<MessageIntegrity>()
|
||||
.ok_or(unauthorized(message))?
|
||||
.clone();
|
||||
let username = message
|
||||
.get_attribute::<Username>()
|
||||
.ok_or(bad_request(message))?
|
||||
.clone();
|
||||
let xor_peer_address = message
|
||||
.get_attribute::<XorPeerAddress>()
|
||||
.ok_or(bad_request(message))?
|
||||
.clone();
|
||||
|
||||
Ok(ChannelBind {
|
||||
transaction_id,
|
||||
channel_number,
|
||||
message_integrity,
|
||||
xor_peer_address,
|
||||
username,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn transaction_id(&self) -> TransactionId {
|
||||
self.transaction_id
|
||||
}
|
||||
|
||||
pub fn channel_number(&self) -> ChannelNumber {
|
||||
self.channel_number
|
||||
}
|
||||
|
||||
pub fn message_integrity(&self) -> &MessageIntegrity {
|
||||
&self.message_integrity
|
||||
}
|
||||
|
||||
pub fn xor_peer_address(&self) -> &XorPeerAddress {
|
||||
&self.xor_peer_address
|
||||
}
|
||||
|
||||
pub fn username(&self) -> &Username {
|
||||
&self.username
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CreatePermission {
|
||||
transaction_id: TransactionId,
|
||||
message_integrity: MessageIntegrity,
|
||||
username: Username,
|
||||
}
|
||||
|
||||
impl CreatePermission {
|
||||
pub fn new(
|
||||
transaction_id: TransactionId,
|
||||
message_integrity: MessageIntegrity,
|
||||
username: Username,
|
||||
) -> Self {
|
||||
Self {
|
||||
transaction_id,
|
||||
message_integrity,
|
||||
username,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse(message: &Message<Attribute>) -> Result<Self, Message<Attribute>> {
|
||||
let transaction_id = message.transaction_id();
|
||||
let message_integrity = message
|
||||
.get_attribute::<MessageIntegrity>()
|
||||
.ok_or(unauthorized(message))?
|
||||
.clone();
|
||||
let username = message
|
||||
.get_attribute::<Username>()
|
||||
.ok_or(bad_request(message))?
|
||||
.clone();
|
||||
|
||||
Ok(CreatePermission {
|
||||
transaction_id,
|
||||
message_integrity,
|
||||
username,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn transaction_id(&self) -> TransactionId {
|
||||
self.transaction_id
|
||||
}
|
||||
|
||||
pub fn message_integrity(&self) -> &MessageIntegrity {
|
||||
&self.message_integrity
|
||||
}
|
||||
|
||||
pub fn username(&self) -> &Username {
|
||||
&self.username
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the effective lifetime of an allocation.
|
||||
fn compute_effective_lifetime(requested_lifetime: Option<&Lifetime>) -> Lifetime {
|
||||
let Some(requested) = requested_lifetime else {
|
||||
return Lifetime::new(DEFAULT_ALLOCATION_LIFETIME).unwrap();
|
||||
};
|
||||
|
||||
let effective_lifetime = requested.lifetime().min(MAX_ALLOCATION_LIFETIME);
|
||||
|
||||
Lifetime::new(effective_lifetime).unwrap()
|
||||
}
|
||||
|
||||
fn bad_request(message: &Message<Attribute>) -> Message<Attribute> {
|
||||
let mut message = Message::new(
|
||||
MessageClass::ErrorResponse,
|
||||
message.method(),
|
||||
message.transaction_id(),
|
||||
);
|
||||
message.add_attribute(ErrorCode::from(BadRequest).into());
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
fn unauthorized(message: &Message<Attribute>) -> Message<Attribute> {
|
||||
let mut message = Message::new(
|
||||
MessageClass::ErrorResponse,
|
||||
message.method(),
|
||||
message.transaction_id(),
|
||||
);
|
||||
message.add_attribute(ErrorCode::from(Unauthorized).into());
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
BadChannelData(io::Error),
|
||||
DecodeStun(bytecodec::Error),
|
||||
UnknownMessageType(u8),
|
||||
Eof,
|
||||
}
|
||||
|
||||
impl From<BrokenMessage> for Error {
|
||||
fn from(msg: BrokenMessage) -> Self {
|
||||
Error::DecodeStun(msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<bytecodec::Error> for Error {
|
||||
fn from(error: bytecodec::Error) -> Self {
|
||||
Error::DecodeStun(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(error: io::Error) -> Self {
|
||||
Error::BadChannelData(error)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn requested_lifetime_is_capped_at_max_lifetime() {
|
||||
let requested_lifetime = Lifetime::new(Duration::from_secs(10_000_000)).unwrap();
|
||||
|
||||
let effective_lifetime = compute_effective_lifetime(Some(&requested_lifetime));
|
||||
|
||||
assert_eq!(effective_lifetime.lifetime(), MAX_ALLOCATION_LIFETIME)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use bytecodec::EncodeExt;
|
||||
use bytecodec::{DecodeExt, EncodeExt};
|
||||
use hex_literal::hex;
|
||||
use relay::{AllocationId, Attribute, Command, Server};
|
||||
use std::collections::HashMap;
|
||||
@@ -6,7 +6,9 @@ use std::time::{Duration, Instant};
|
||||
use stun_codec::rfc5389::attributes::{MessageIntegrity, Realm, Username};
|
||||
use stun_codec::rfc5766::attributes::Lifetime;
|
||||
use stun_codec::rfc5766::methods::REFRESH;
|
||||
use stun_codec::{Message, MessageClass, MessageEncoder, TransactionId};
|
||||
use stun_codec::{
|
||||
DecodedMessage, Message, MessageClass, MessageDecoder, MessageEncoder, TransactionId,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn stun_binding_request() {
|
||||
@@ -123,11 +125,6 @@ fn ping_pong_relay() {
|
||||
Output::send_message("127.0.0.1:42677","010800002112a442dc5c115f6b727e25a54b55d3")
|
||||
]),
|
||||
(
|
||||
Input::client("127.0.0.1:42677","001600242112a4421b6fb3ce9cefcd57ef3a9edb00130009484f4c4550554e4348000000001200080001f2405e12a443802800041bfe0967", now),
|
||||
&[
|
||||
Output::send_message("127.0.0.1:42677","011600142112a4421b6fb3ce9cefcd57ef3a9edb0009000f00000400426164205265717565737400")
|
||||
]),
|
||||
(
|
||||
Input::client("127.0.0.1:42677","000900542112a4420afbde5aaacfc1e9316beae9001200080001f2405e12a443000c000440000000000600047465737400140008666972657a6f6e6500150006666f6f626172000000080014aca01c6cdc1fc5339a309e5bccac3df5c903e33e802800041fe4b79b", now),
|
||||
&[
|
||||
Output::send_message("127.0.0.1:42677","010900002112a4420afbde5aaacfc1e9316beae9")
|
||||
@@ -159,7 +156,7 @@ fn run_regression_test(sequence: &[(Input, &[Output])]) {
|
||||
let input = hex::decode(data).unwrap();
|
||||
let from = from.parse().unwrap();
|
||||
|
||||
server.handle_client_input(&input, from, *now).unwrap();
|
||||
server.handle_client_input(&input, from, *now);
|
||||
}
|
||||
Input::Time(now) => {
|
||||
server.handle_deadline_reached(*now);
|
||||
@@ -173,13 +170,30 @@ fn run_regression_test(sequence: &[(Input, &[Output])]) {
|
||||
}
|
||||
|
||||
for expected_output in *output {
|
||||
let actual_output = server
|
||||
.next_command()
|
||||
.unwrap_or_else(|| panic!("no commands produced but expected {expected_output:?}"));
|
||||
let Some(actual_output) = server.next_command() else {
|
||||
let msg = match expected_output {
|
||||
Output::SendMessage((recipient, bytes)) => format!("to send message {:?} to {recipient}", parse_hex_message(bytes)),
|
||||
Output::Forward((ip, data, port)) => format!("forward '{data}' to {ip} on port {port}"),
|
||||
Output::Wake(instant) => format!("to be woken at {instant:?}"),
|
||||
Output::CreateAllocation(port) => format!("to create allocation on port {port}"),
|
||||
Output::ExpireAllocation(port) => format!("to free allocation on port {port}"),
|
||||
};
|
||||
|
||||
panic!("No commands produced but expected {msg}");
|
||||
};
|
||||
|
||||
match (expected_output, actual_output) {
|
||||
(Output::SendMessage((to, bytes)), Command::SendMessage { payload, recipient }) => {
|
||||
assert_eq!(*bytes, hex::encode(payload));
|
||||
let expected_bytes = hex::decode(bytes).unwrap();
|
||||
|
||||
if expected_bytes != payload {
|
||||
let expected_message =
|
||||
format!("{:?}", parse_message(expected_bytes.as_ref()));
|
||||
let actual_message = format!("{:?}", parse_message(payload.as_ref()));
|
||||
|
||||
difference::assert_diff!(&expected_message, &actual_message, "\n", 0);
|
||||
}
|
||||
|
||||
assert_eq!(recipient, to.parse().unwrap());
|
||||
}
|
||||
(
|
||||
@@ -240,6 +254,15 @@ where
|
||||
hex::encode(MessageEncoder::new().encode_into_bytes(message).unwrap())
|
||||
}
|
||||
|
||||
fn parse_hex_message(message: &str) -> DecodedMessage<Attribute> {
|
||||
let message = hex::decode(message).unwrap();
|
||||
MessageDecoder::new().decode_from_bytes(&message).unwrap()
|
||||
}
|
||||
|
||||
fn parse_message(message: &[u8]) -> DecodedMessage<Attribute> {
|
||||
MessageDecoder::new().decode_from_bytes(message).unwrap()
|
||||
}
|
||||
|
||||
enum Input {
|
||||
Client(Ip, Bytes, Instant),
|
||||
Peer(Ip, Bytes, u16),
|
||||
|
||||
Reference in New Issue
Block a user