refactor(relay): introduce type-safe Server APIs (#1630)

We introduce dedicated types for each message that the `Server` can
handle. This allows us to make the functions public because the
type-system now guarantees that those are either parsed from bytes or
constructed with the correct data.

The latter will be useful to write tests against a richer API.
This commit is contained in:
Thomas Eizinger
2023-05-31 15:18:20 +01:00
committed by GitHub
parent 37a2d7b7f5
commit d27856a8f1
7 changed files with 618 additions and 214 deletions

7
rust/Cargo.lock generated
View File

@@ -603,6 +603,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "difference"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "524cbf6897b527295dff137cec09ecf3a05f4fddffd7dfcd1585403449e74198"
[[package]]
name = "digest"
version = "0.9.0"
@@ -1525,6 +1531,7 @@ dependencies = [
"anyhow",
"bytecodec",
"bytes",
"difference",
"env_logger",
"futures",
"hex",

View File

@@ -20,3 +20,4 @@ bytes = "1.4.0"
[dev-dependencies]
webrtc = "0.7.2"
redis = { version = "0.23.0", default-features = false, features = ["tokio-comp"] }
difference = "2.0.0"

View File

@@ -190,7 +190,7 @@ impl Eventloop {
// Priority 6: Accept new allocations / answer STUN requests etc
if let Poll::Ready((buffer, sender)) = self.ip4_socket.poll_recv(cx)? {
self.server
.handle_client_input(buffer.filled(), sender, Instant::now())?;
.handle_client_input(buffer.filled(), sender, Instant::now());
continue; // Handle potentially new commands.
}

View File

@@ -1,10 +1,15 @@
mod channel_data;
mod client_message;
use crate::rfc8656::PeerAddressFamilyMismatch;
use crate::server::channel_data::ChannelData;
use crate::server::client_message::{
Allocate, Binding, ChannelBind, ClientMessage, CreatePermission, Refresh,
};
use crate::stun_codec_ext::{MessageClassExt, MethodExt};
use crate::TimeEvents;
use anyhow::Result;
use bytecodec::{DecodeExt, EncodeExt};
use bytecodec::EncodeExt;
use core::fmt;
use rand::rngs::mock::StepRng;
use rand::rngs::ThreadRng;
@@ -23,7 +28,7 @@ use stun_codec::rfc5766::attributes::{
};
use stun_codec::rfc5766::errors::{AllocationMismatch, InsufficientCapacity};
use stun_codec::rfc5766::methods::{ALLOCATE, CHANNEL_BIND, CREATE_PERMISSION, REFRESH};
use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder, Method, TransactionId};
use stun_codec::{Message, MessageClass, MessageEncoder, Method, TransactionId};
/// A sans-IO STUN & TURN server.
///
@@ -33,7 +38,7 @@ use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder, Method,
///
/// Additionally, we assume to have complete ownership over the port range `LOWEST_PORT` - `HIGHEST_PORT`.
pub struct Server<R = ThreadRng> {
decoder: MessageDecoder<Attribute>,
decoder: client_message::Decoder,
encoder: MessageEncoder<Attribute>,
public_ip4_address: Ipv4Addr,
@@ -108,14 +113,6 @@ const HIGHEST_PORT: u16 = 65535;
/// The maximum number of ports available for allocation.
const MAX_AVAILABLE_PORTS: u16 = HIGHEST_PORT - LOWEST_PORT;
/// The maximum lifetime of an allocation.
const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600);
/// The default lifetime of an allocation.
///
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-allocations-2>.
const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600);
/// The duration of a channel binding.
///
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-channels-2>.
@@ -149,71 +146,58 @@ where
/// Process the bytes received from a client.
///
/// After calling this method, you should call [`Server::next_command`] until it returns `None`.
pub fn handle_client_input(
&mut self,
bytes: &[u8],
sender: SocketAddr,
now: Instant,
) -> Result<()> {
pub fn handle_client_input(&mut self, bytes: &[u8], sender: SocketAddr, now: Instant) {
if tracing::enabled!(target: "wire", tracing::Level::TRACE) {
let hex_bytes = hex::encode(bytes);
tracing::trace!(target: "wire", r#"Input::client("{sender}","{hex_bytes}")"#);
}
// De-multiplex as per <https://www.rfc-editor.org/rfc/rfc8656#name-channels-2>.
match bytes.first() {
Some(0..=3) => {
let Ok(message) = self.decoder.decode_from_bytes(bytes)? else {
tracing::warn!(target: "relay", "received broken STUN message from {sender}");
return Ok(());
};
tracing::trace!(target: "relay", "Received {} {} from {sender}", message.method().as_str(), message.class().as_str());
self.dispatch_stun_message(message, sender, now, |server, message, sender, now| {
use MessageClass::*;
match (message.method(), message.class()) {
(BINDING, Request) => {
server.handle_binding_request(message, sender);
Ok(())
}
(ALLOCATE, Request) => server.handle_allocate_request(message, sender, now),
(REFRESH, Request) => server.handle_refresh_request(message, sender, now),
(CHANNEL_BIND, Request) => {
server.handle_channel_bind_request(message, sender, now)
}
(CREATE_PERMISSION, Request) => {
server.handle_create_permission_request(message, sender, now)
}
(_, Indication) => {
tracing::trace!(target: "relay", "Indications are not yet implemented");
Err(ErrorCode::from(BadRequest))
}
_ => Err(ErrorCode::from(BadRequest)),
}
});
let result = match self.decoder.decode(bytes) {
Ok(Ok(ClientMessage::Allocate(request))) => {
self.handle_allocate_request(request, sender, now)
}
Ok(Ok(ClientMessage::Refresh(request))) => {
self.handle_refresh_request(request, sender, now)
}
Ok(Ok(ClientMessage::ChannelBind(request))) => {
self.handle_channel_bind_request(request, sender, now)
}
Ok(Ok(ClientMessage::CreatePermission(request))) => {
self.handle_create_permission_request(request, sender, now)
}
Ok(Ok(ClientMessage::Binding(request))) => {
self.handle_binding_request(request, sender);
return;
}
Ok(Ok(ClientMessage::ChannelData(msg))) => {
self.handle_channel_data_message(msg, sender, now);
return;
}
Some(64..=79) => {
let (channel, data) = match channel_data::parse(bytes) {
Ok(v) => v,
Err(e) => {
tracing::debug!(
target: "relay",
"failed to parse channel data message: {e:#}"
);
return Ok(());
}
};
self.handle_channel_data_message(channel, data, sender, now);
}
_ => {
tracing::trace!(target: "relay", "Received unknown message from {sender}");
}
// Could parse the bytes but message was semantically invalid (like missing attribute).
Ok(Err(error_code)) => Err(error_code),
// Parsing the bytes failed.
Err(client_message::Error::BadChannelData(_)) => return,
Err(client_message::Error::DecodeStun(_)) => return,
Err(client_message::Error::UnknownMessageType(_)) => return,
Err(client_message::Error::Eof) => return,
};
let Err(mut error_response) = result else {
return;
};
// In case of a 401 response, attach a realm and nonce.
if error_response
.get_attribute::<ErrorCode>()
.map_or(false, |error| error == &ErrorCode::from(Unauthorized))
{
error_response.add_attribute(Nonce::new("foobar".to_owned()).unwrap().into());
error_response.add_attribute(Realm::new("firezone".to_owned()).unwrap().into());
}
Ok(())
self.send_message(error_response, sender);
}
/// Process the bytes received from an allocation.
@@ -264,7 +248,7 @@ where
);
let recipient = *client;
let data = channel_data::make(*channel_number, bytes);
let data = ChannelData::new(*channel_number, bytes).to_bytes();
if tracing::enabled!(target: "wire", tracing::Level::TRACE) {
let hex_bytes = hex::encode(&data);
@@ -321,27 +305,7 @@ where
self.pending_commands.pop_front()
}
fn dispatch_stun_message(
&mut self,
message: Message<Attribute>,
sender: SocketAddr,
now: Instant,
handler: impl Fn(&mut Self, Message<Attribute>, SocketAddr, Instant) -> Result<(), ErrorCode>,
) {
let transaction_id = message.transaction_id();
let method = message.method();
if let Err(e) = handler(self, message, sender, now) {
if e.code() == Unauthorized::CODEPOINT {
self.send_message(unauthorized(transaction_id, method), sender);
return;
}
self.send_message(error_response(transaction_id, method, e), sender);
}
}
fn handle_binding_request(&mut self, message: Message<Attribute>, sender: SocketAddr) {
pub fn handle_binding_request(&mut self, message: Binding, sender: SocketAddr) {
let mut message = Message::new(
MessageClass::SuccessResponse,
BINDING,
@@ -355,35 +319,27 @@ where
/// Handle a TURN allocate request.
///
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-receiving-an-allocate-reque> for details.
fn handle_allocate_request(
pub fn handle_allocate_request(
&mut self,
message: Message<Attribute>,
request: Allocate,
sender: SocketAddr,
now: Instant,
) -> Result<(), ErrorCode> {
let _ = message
.get_attribute::<MessageIntegrity>()
.ok_or(Unauthorized)?;
) -> Result<(), Message<Attribute>> {
// TODO: Check validity of message integrity here?
if self.allocations.contains_key(&sender) {
return Err(AllocationMismatch.into());
return Err(error_response(AllocationMismatch, &request));
}
if self.allocations_by_port.len() == MAX_AVAILABLE_PORTS as usize {
return Err(InsufficientCapacity.into());
return Err(error_response(InsufficientCapacity, &request));
}
let requested_transport = message
.get_attribute::<RequestedTransport>()
.ok_or(BadRequest)?;
if requested_transport.protocol() != UDP_TRANSPORT {
return Err(BadRequest.into());
if request.requested_transport().protocol() != UDP_TRANSPORT {
return Err(error_response(BadRequest, &request));
}
let requested_lifetime = message.get_attribute::<Lifetime>();
let effective_lifetime = compute_effective_lifetime(requested_lifetime);
let effective_lifetime = request.effective_lifetime();
// TODO: Do we need to handle DONT-FRAGMENT?
// TODO: Do we need to handle EVEN/ODD-PORT?
@@ -393,7 +349,7 @@ where
let mut message = Message::new(
MessageClass::SuccessResponse,
ALLOCATE,
message.transaction_id(),
request.transaction_id(),
);
let ip4_relay_address = self.public_relay_address_for_port(allocation.port);
@@ -430,24 +386,21 @@ where
/// Handle a TURN refresh request.
///
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-receiving-a-refresh-request> for details.
fn handle_refresh_request(
pub fn handle_refresh_request(
&mut self,
message: Message<Attribute>,
request: Refresh,
sender: SocketAddr,
now: Instant,
) -> Result<(), ErrorCode> {
let _ = message
.get_attribute::<MessageIntegrity>()
.ok_or(Unauthorized)?;
) -> Result<(), Message<Attribute>> {
// TODO: Check validity of message integrity here?
// TODO: Verify that this is the correct error code.
let allocation = self
.allocations
.get_mut(&sender)
.ok_or(ErrorCode::from(AllocationMismatch))?;
.ok_or(error_response(AllocationMismatch, &request))?;
let requested_lifetime = message.get_attribute::<Lifetime>();
let effective_lifetime = compute_effective_lifetime(requested_lifetime);
let effective_lifetime = request.effective_lifetime();
if effective_lifetime.lifetime().is_zero() {
let port = allocation.port;
@@ -457,7 +410,7 @@ where
self.allocations.remove(&sender);
self.allocations_by_port.remove(&port);
self.send_message(
refresh_success_response(effective_lifetime, message.transaction_id()),
refresh_success_response(effective_lifetime, request.transaction_id()),
sender,
);
@@ -486,7 +439,7 @@ where
deadline: wake_deadline,
});
self.send_message(
refresh_success_response(effective_lifetime, message.transaction_id()),
refresh_success_response(effective_lifetime, request.transaction_id()),
sender,
);
@@ -496,49 +449,40 @@ where
/// Handle a TURN channel bind request.
///
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-receiving-a-channelbind-req> for details.
fn handle_channel_bind_request(
pub fn handle_channel_bind_request(
&mut self,
message: Message<Attribute>,
request: ChannelBind,
sender: SocketAddr,
now: Instant,
) -> Result<(), ErrorCode> {
let _ = message
.get_attribute::<MessageIntegrity>()
.ok_or(Unauthorized)?;
) -> Result<(), Message<Attribute>> {
// TODO: Check validity of message integrity here?
let allocation = self
.allocations
.get_mut(&sender)
.ok_or(ErrorCode::from(AllocationMismatch))?;
.ok_or(error_response(AllocationMismatch, &request))?;
let requested_channel = message
.get_attribute::<ChannelNumber>()
.ok_or(ErrorCode::from(BadRequest))?
.value();
let peer_address = message
.get_attribute::<XorPeerAddress>()
.ok_or(ErrorCode::from(BadRequest))?
.address();
let requested_channel = request.channel_number().value();
let peer_address = request.xor_peer_address().address();
// Note: `channel_number` is enforced to be in the correct range.
// Check that our allocation can handle the requested peer addr.
if !allocation.can_relay_to(peer_address) {
return Err(ErrorCode::from(PeerAddressFamilyMismatch));
return Err(error_response(PeerAddressFamilyMismatch, &request));
}
// Ensure the same address isn't already bound to a different channel.
if let Some(number) = self.channel_numbers_by_peer.get(&peer_address) {
if number != &requested_channel {
return Err(ErrorCode::from(BadRequest));
return Err(error_response(BadRequest, &request));
}
}
// Ensure the channel is not already bound to a different address.
if let Some(channel) = self.channels_by_number.get_mut(&requested_channel) {
if channel.peer_address != peer_address {
return Err(ErrorCode::from(BadRequest));
return Err(error_response(BadRequest, &request));
}
// Binding requests for existing channels act as a refresh for the binding.
@@ -552,7 +496,7 @@ where
TimedAction::UnbindChannel(requested_channel),
);
self.send_message(
channel_bind_success_response(message.transaction_id()),
channel_bind_success_response(request.transaction_id()),
sender,
);
@@ -567,7 +511,7 @@ where
let allocation_id = allocation.id;
self.create_channel_binding(requested_channel, peer_address, allocation_id, now);
self.send_message(
channel_bind_success_response(message.transaction_id()),
channel_bind_success_response(request.transaction_id()),
sender,
);
@@ -582,15 +526,14 @@ where
///
/// This TURN server implementation does not support relaying data other than through channels.
/// Thus, creating a permission is a no-op that always succeeds.
fn handle_create_permission_request(
pub fn handle_create_permission_request(
&mut self,
message: Message<Attribute>,
message: CreatePermission,
sender: SocketAddr,
_: Instant,
) -> Result<(), ErrorCode> {
let _ = message
.get_attribute::<MessageIntegrity>()
.ok_or(Unauthorized)?;
) -> Result<(), Message<Attribute>> {
// TODO: Check validity of message integrity here?
self.send_message(
create_permission_success_response(message.transaction_id()),
sender,
@@ -599,13 +542,15 @@ where
Ok(())
}
fn handle_channel_data_message(
pub fn handle_channel_data_message(
&mut self,
channel_number: u16,
data: &[u8],
message: ChannelData,
sender: SocketAddr,
_: Instant,
) {
let channel_number = message.channel();
let data = message.data();
let Some(channel) = self.channels_by_number.get(&channel_number) else {
tracing::debug!(target: "relay", "Channel {channel_number} does not exist, refusing to forward data");
return;
@@ -834,37 +779,45 @@ enum TimedAction {
DeleteChannel(u16),
}
/// Computes the effective lifetime of an allocation.
fn compute_effective_lifetime(requested_lifetime: Option<&Lifetime>) -> Lifetime {
let Some(requested) = requested_lifetime else {
return Lifetime::new(DEFAULT_ALLOCATION_LIFETIME).unwrap();
};
let effective_lifetime = requested.lifetime().min(MAX_ALLOCATION_LIFETIME);
Lifetime::new(effective_lifetime).unwrap()
}
fn error_response(
transaction_id: TransactionId,
method: Method,
error_code: ErrorCode,
error_code: impl Into<ErrorCode>,
request: &impl StunRequest,
) -> Message<Attribute> {
let mut message = Message::new(MessageClass::ErrorResponse, method, transaction_id);
message.add_attribute(error_code.into());
let mut message = Message::new(
MessageClass::ErrorResponse,
request.method(),
request.transaction_id(),
);
message.add_attribute(Attribute::from(error_code.into()));
message
}
fn unauthorized(transaction_id: TransactionId, method: Method) -> Message<Attribute> {
let mut message = Message::new(MessageClass::ErrorResponse, method, transaction_id);
message.add_attribute(ErrorCode::from(Unauthorized).into());
message.add_attribute(Nonce::new("foobar".to_owned()).unwrap().into());
message.add_attribute(Realm::new("firezone".to_owned()).unwrap().into());
message
/// Private helper trait to make [`error_response`] more ergonomic to use.
trait StunRequest {
fn transaction_id(&self) -> TransactionId;
fn method(&self) -> Method;
}
macro_rules! impl_stun_request_for {
($t:ty, $m:expr) => {
impl StunRequest for $t {
fn transaction_id(&self) -> TransactionId {
self.transaction_id()
}
fn method(&self) -> Method {
$m
}
}
};
}
impl_stun_request_for!(Allocate, ALLOCATE);
impl_stun_request_for!(ChannelBind, CHANNEL_BIND);
impl_stun_request_for!(CreatePermission, CREATE_PERMISSION);
impl_stun_request_for!(Refresh, REFRESH);
// Define an enum of all attributes that we care about for our server.
stun_codec::define_attribute_enums!(
Attribute,
@@ -884,17 +837,3 @@ stun_codec::define_attribute_enums!(
Username
]
);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn requested_lifetime_is_capped_at_max_lifetime() {
let requested_lifetime = Lifetime::new(Duration::from_secs(10_000_000)).unwrap();
let effective_lifetime = compute_effective_lifetime(Some(&requested_lifetime));
assert_eq!(effective_lifetime.lifetime(), MAX_ALLOCATION_LIFETIME)
}
}

View File

@@ -1,27 +1,61 @@
use anyhow::{bail, Result};
use bytes::{BufMut, BytesMut};
use std::io;
pub(crate) fn make(channel: u16, data: &[u8]) -> Vec<u8> {
let mut message = BytesMut::with_capacity(2 + 2 + data.len());
message.put_u16(channel);
message.put_u16(data.len() as u16);
message.put_slice(data);
message.freeze().to_vec()
pub struct ChannelData<'a> {
channel: u16,
data: &'a [u8],
}
pub(crate) fn parse(data: &[u8]) -> Result<(u16, &[u8])> {
if data.len() < 4 {
bail!("must have at least 4 bytes for channel data message")
impl<'a> ChannelData<'a> {
pub fn parse(data: &'a [u8]) -> Result<Self, io::Error> {
if data.len() < 4 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"channel data messages are at least 4 bytes long",
));
}
let channel_number = u16::from_be_bytes([data[0], data[1]]);
let length = u16::from_be_bytes([data[2], data[3]]);
let actual_payload_length = data.len() - 4;
if actual_payload_length != length as usize {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"channel data message specified {length} bytes but got {actual_payload_length}"
),
));
}
Ok(ChannelData {
channel: channel_number,
data: &data[4..],
})
}
let channel_number = u16::from_be_bytes([data[0], data[1]]);
let length = u16::from_be_bytes([data[2], data[3]]);
pub fn new(channel: u16, data: &'a [u8]) -> Self {
ChannelData { channel, data }
}
anyhow::ensure!((data.len() - 4) == length as usize);
pub fn to_bytes(&self) -> Vec<u8> {
let mut message = BytesMut::with_capacity(2 + 2 + self.data.len());
Ok((channel_number, &data[4..]))
message.put_u16(self.channel);
message.put_u16(self.data.len() as u16);
message.put_slice(self.data);
message.freeze().into()
}
pub fn channel(&self) -> u16 {
self.channel
}
pub fn data(&self) -> &[u8] {
self.data
}
}
// TODO: tests

View File

@@ -0,0 +1,400 @@
use crate::server::channel_data::ChannelData;
use crate::Attribute;
use bytecodec::DecodeExt;
use std::io;
use std::time::Duration;
use stun_codec::rfc5389::attributes::{ErrorCode, MessageIntegrity, Username};
use stun_codec::rfc5389::errors::{BadRequest, Unauthorized};
use stun_codec::rfc5389::methods::BINDING;
use stun_codec::rfc5766::attributes::{
ChannelNumber, Lifetime, RequestedTransport, XorPeerAddress,
};
use stun_codec::rfc5766::methods::{ALLOCATE, CHANNEL_BIND, CREATE_PERMISSION, REFRESH};
use stun_codec::{BrokenMessage, Message, MessageClass, TransactionId};
/// The maximum lifetime of an allocation.
const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600);
/// The default lifetime of an allocation.
///
/// See <https://www.rfc-editor.org/rfc/rfc8656#name-allocations-2>.
const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600);
#[derive(Default)]
pub struct Decoder {
stun_message_decoder: stun_codec::MessageDecoder<Attribute>,
}
impl Decoder {
pub fn decode<'a>(
&mut self,
input: &'a [u8],
) -> Result<Result<ClientMessage<'a>, Message<Attribute>>, Error> {
// De-multiplex as per <https://www.rfc-editor.org/rfc/rfc8656#name-channels-2>.
match input.first() {
Some(0..=3) => {
let message = self.stun_message_decoder.decode_from_bytes(input)??;
use MessageClass::*;
match (message.method(), message.class()) {
(BINDING, Request) => Ok(Ok(ClientMessage::Binding(Binding::parse(&message)))),
(ALLOCATE, Request) => {
Ok(Allocate::parse(&message).map(ClientMessage::Allocate))
}
(REFRESH, Request) => Ok(Refresh::parse(&message).map(ClientMessage::Refresh)),
(CHANNEL_BIND, Request) => {
Ok(ChannelBind::parse(&message).map(ClientMessage::ChannelBind))
}
(CREATE_PERMISSION, Request) => {
Ok(CreatePermission::parse(&message).map(ClientMessage::CreatePermission))
}
(_, Request) => Ok(Err(bad_request(&message))),
(method, class) => {
Err(Error::DecodeStun(bytecodec::Error::from(io::Error::new(
io::ErrorKind::Unsupported,
format!(
"handling method {} and {class:?} is not implemented",
method.as_u16()
),
))))
}
}
}
Some(64..=79) => Ok(Ok(ClientMessage::ChannelData(ChannelData::parse(input)?))),
Some(other) => Err(Error::UnknownMessageType(*other)),
None => Err(Error::Eof),
}
}
}
pub enum ClientMessage<'a> {
ChannelData(ChannelData<'a>),
Binding(Binding),
Allocate(Allocate),
Refresh(Refresh),
ChannelBind(ChannelBind),
CreatePermission(CreatePermission),
}
pub struct Binding {
transaction_id: TransactionId,
}
impl Binding {
pub fn new(transaction_id: TransactionId) -> Self {
Self { transaction_id }
}
pub fn parse(message: &Message<Attribute>) -> Self {
let transaction_id = message.transaction_id();
Binding { transaction_id }
}
pub fn transaction_id(&self) -> TransactionId {
self.transaction_id
}
}
pub struct Allocate {
transaction_id: TransactionId,
message_integrity: MessageIntegrity,
requested_transport: RequestedTransport,
lifetime: Option<Lifetime>,
}
impl Allocate {
pub fn new(
transaction_id: TransactionId,
message_integrity: MessageIntegrity,
requested_transport: RequestedTransport,
lifetime: Option<Lifetime>,
) -> Self {
Self {
transaction_id,
message_integrity,
requested_transport,
lifetime,
}
}
pub fn parse(message: &Message<Attribute>) -> Result<Self, Message<Attribute>> {
let transaction_id = message.transaction_id();
let message_integrity = message
.get_attribute::<MessageIntegrity>()
.ok_or(unauthorized(message))?
.clone();
let requested_transport = message
.get_attribute::<RequestedTransport>()
.ok_or(bad_request(message))?
.clone();
let lifetime = message.get_attribute::<Lifetime>().cloned();
Ok(Allocate {
transaction_id,
message_integrity,
requested_transport,
lifetime,
})
}
pub fn transaction_id(&self) -> TransactionId {
self.transaction_id
}
pub fn message_integrity(&self) -> &MessageIntegrity {
&self.message_integrity
}
pub fn requested_transport(&self) -> &RequestedTransport {
&self.requested_transport
}
pub fn effective_lifetime(&self) -> Lifetime {
compute_effective_lifetime(self.lifetime.as_ref())
}
}
pub struct Refresh {
transaction_id: TransactionId,
message_integrity: MessageIntegrity,
lifetime: Option<Lifetime>,
}
impl Refresh {
pub fn new(
transaction_id: TransactionId,
message_integrity: MessageIntegrity,
lifetime: Option<Lifetime>,
) -> Self {
Self {
transaction_id,
message_integrity,
lifetime,
}
}
pub fn parse(message: &Message<Attribute>) -> Result<Self, Message<Attribute>> {
let transaction_id = message.transaction_id();
let message_integrity = message
.get_attribute::<MessageIntegrity>()
.ok_or(unauthorized(message))?
.clone();
let lifetime = message.get_attribute::<Lifetime>().cloned();
Ok(Refresh {
transaction_id,
message_integrity,
lifetime,
})
}
pub fn transaction_id(&self) -> TransactionId {
self.transaction_id
}
pub fn message_integrity(&self) -> &MessageIntegrity {
&self.message_integrity
}
pub fn effective_lifetime(&self) -> Lifetime {
compute_effective_lifetime(self.lifetime.as_ref())
}
}
pub struct ChannelBind {
transaction_id: TransactionId,
channel_number: ChannelNumber,
message_integrity: MessageIntegrity,
xor_peer_address: XorPeerAddress,
username: Username,
}
impl ChannelBind {
pub fn new(
transaction_id: TransactionId,
channel_number: ChannelNumber,
message_integrity: MessageIntegrity,
username: Username,
xor_peer_address: XorPeerAddress,
) -> Self {
Self {
transaction_id,
channel_number,
message_integrity,
xor_peer_address,
username,
}
}
pub fn parse(message: &Message<Attribute>) -> Result<Self, Message<Attribute>> {
let transaction_id = message.transaction_id();
let channel_number = message
.get_attribute::<ChannelNumber>()
.copied()
.ok_or(bad_request(message))?;
let message_integrity = message
.get_attribute::<MessageIntegrity>()
.ok_or(unauthorized(message))?
.clone();
let username = message
.get_attribute::<Username>()
.ok_or(bad_request(message))?
.clone();
let xor_peer_address = message
.get_attribute::<XorPeerAddress>()
.ok_or(bad_request(message))?
.clone();
Ok(ChannelBind {
transaction_id,
channel_number,
message_integrity,
xor_peer_address,
username,
})
}
pub fn transaction_id(&self) -> TransactionId {
self.transaction_id
}
pub fn channel_number(&self) -> ChannelNumber {
self.channel_number
}
pub fn message_integrity(&self) -> &MessageIntegrity {
&self.message_integrity
}
pub fn xor_peer_address(&self) -> &XorPeerAddress {
&self.xor_peer_address
}
pub fn username(&self) -> &Username {
&self.username
}
}
pub struct CreatePermission {
transaction_id: TransactionId,
message_integrity: MessageIntegrity,
username: Username,
}
impl CreatePermission {
pub fn new(
transaction_id: TransactionId,
message_integrity: MessageIntegrity,
username: Username,
) -> Self {
Self {
transaction_id,
message_integrity,
username,
}
}
pub fn parse(message: &Message<Attribute>) -> Result<Self, Message<Attribute>> {
let transaction_id = message.transaction_id();
let message_integrity = message
.get_attribute::<MessageIntegrity>()
.ok_or(unauthorized(message))?
.clone();
let username = message
.get_attribute::<Username>()
.ok_or(bad_request(message))?
.clone();
Ok(CreatePermission {
transaction_id,
message_integrity,
username,
})
}
pub fn transaction_id(&self) -> TransactionId {
self.transaction_id
}
pub fn message_integrity(&self) -> &MessageIntegrity {
&self.message_integrity
}
pub fn username(&self) -> &Username {
&self.username
}
}
/// Computes the effective lifetime of an allocation.
fn compute_effective_lifetime(requested_lifetime: Option<&Lifetime>) -> Lifetime {
let Some(requested) = requested_lifetime else {
return Lifetime::new(DEFAULT_ALLOCATION_LIFETIME).unwrap();
};
let effective_lifetime = requested.lifetime().min(MAX_ALLOCATION_LIFETIME);
Lifetime::new(effective_lifetime).unwrap()
}
fn bad_request(message: &Message<Attribute>) -> Message<Attribute> {
let mut message = Message::new(
MessageClass::ErrorResponse,
message.method(),
message.transaction_id(),
);
message.add_attribute(ErrorCode::from(BadRequest).into());
message
}
fn unauthorized(message: &Message<Attribute>) -> Message<Attribute> {
let mut message = Message::new(
MessageClass::ErrorResponse,
message.method(),
message.transaction_id(),
);
message.add_attribute(ErrorCode::from(Unauthorized).into());
message
}
#[derive(Debug)]
pub enum Error {
BadChannelData(io::Error),
DecodeStun(bytecodec::Error),
UnknownMessageType(u8),
Eof,
}
impl From<BrokenMessage> for Error {
fn from(msg: BrokenMessage) -> Self {
Error::DecodeStun(msg.into())
}
}
impl From<bytecodec::Error> for Error {
fn from(error: bytecodec::Error) -> Self {
Error::DecodeStun(error)
}
}
impl From<io::Error> for Error {
fn from(error: io::Error) -> Self {
Error::BadChannelData(error)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn requested_lifetime_is_capped_at_max_lifetime() {
let requested_lifetime = Lifetime::new(Duration::from_secs(10_000_000)).unwrap();
let effective_lifetime = compute_effective_lifetime(Some(&requested_lifetime));
assert_eq!(effective_lifetime.lifetime(), MAX_ALLOCATION_LIFETIME)
}
}

View File

@@ -1,4 +1,4 @@
use bytecodec::EncodeExt;
use bytecodec::{DecodeExt, EncodeExt};
use hex_literal::hex;
use relay::{AllocationId, Attribute, Command, Server};
use std::collections::HashMap;
@@ -6,7 +6,9 @@ use std::time::{Duration, Instant};
use stun_codec::rfc5389::attributes::{MessageIntegrity, Realm, Username};
use stun_codec::rfc5766::attributes::Lifetime;
use stun_codec::rfc5766::methods::REFRESH;
use stun_codec::{Message, MessageClass, MessageEncoder, TransactionId};
use stun_codec::{
DecodedMessage, Message, MessageClass, MessageDecoder, MessageEncoder, TransactionId,
};
#[test]
fn stun_binding_request() {
@@ -123,11 +125,6 @@ fn ping_pong_relay() {
Output::send_message("127.0.0.1:42677","010800002112a442dc5c115f6b727e25a54b55d3")
]),
(
Input::client("127.0.0.1:42677","001600242112a4421b6fb3ce9cefcd57ef3a9edb00130009484f4c4550554e4348000000001200080001f2405e12a443802800041bfe0967", now),
&[
Output::send_message("127.0.0.1:42677","011600142112a4421b6fb3ce9cefcd57ef3a9edb0009000f00000400426164205265717565737400")
]),
(
Input::client("127.0.0.1:42677","000900542112a4420afbde5aaacfc1e9316beae9001200080001f2405e12a443000c000440000000000600047465737400140008666972657a6f6e6500150006666f6f626172000000080014aca01c6cdc1fc5339a309e5bccac3df5c903e33e802800041fe4b79b", now),
&[
Output::send_message("127.0.0.1:42677","010900002112a4420afbde5aaacfc1e9316beae9")
@@ -159,7 +156,7 @@ fn run_regression_test(sequence: &[(Input, &[Output])]) {
let input = hex::decode(data).unwrap();
let from = from.parse().unwrap();
server.handle_client_input(&input, from, *now).unwrap();
server.handle_client_input(&input, from, *now);
}
Input::Time(now) => {
server.handle_deadline_reached(*now);
@@ -173,13 +170,30 @@ fn run_regression_test(sequence: &[(Input, &[Output])]) {
}
for expected_output in *output {
let actual_output = server
.next_command()
.unwrap_or_else(|| panic!("no commands produced but expected {expected_output:?}"));
let Some(actual_output) = server.next_command() else {
let msg = match expected_output {
Output::SendMessage((recipient, bytes)) => format!("to send message {:?} to {recipient}", parse_hex_message(bytes)),
Output::Forward((ip, data, port)) => format!("forward '{data}' to {ip} on port {port}"),
Output::Wake(instant) => format!("to be woken at {instant:?}"),
Output::CreateAllocation(port) => format!("to create allocation on port {port}"),
Output::ExpireAllocation(port) => format!("to free allocation on port {port}"),
};
panic!("No commands produced but expected {msg}");
};
match (expected_output, actual_output) {
(Output::SendMessage((to, bytes)), Command::SendMessage { payload, recipient }) => {
assert_eq!(*bytes, hex::encode(payload));
let expected_bytes = hex::decode(bytes).unwrap();
if expected_bytes != payload {
let expected_message =
format!("{:?}", parse_message(expected_bytes.as_ref()));
let actual_message = format!("{:?}", parse_message(payload.as_ref()));
difference::assert_diff!(&expected_message, &actual_message, "\n", 0);
}
assert_eq!(recipient, to.parse().unwrap());
}
(
@@ -240,6 +254,15 @@ where
hex::encode(MessageEncoder::new().encode_into_bytes(message).unwrap())
}
fn parse_hex_message(message: &str) -> DecodedMessage<Attribute> {
let message = hex::decode(message).unwrap();
MessageDecoder::new().decode_from_bytes(&message).unwrap()
}
fn parse_message(message: &[u8]) -> DecodedMessage<Attribute> {
MessageDecoder::new().decode_from_bytes(message).unwrap()
}
enum Input {
Client(Ip, Bytes, Instant),
Peer(Ip, Bytes, u16),