diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8bf2d2c95..fd1a49da2 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -603,6 +603,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "difference" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524cbf6897b527295dff137cec09ecf3a05f4fddffd7dfcd1585403449e74198" + [[package]] name = "digest" version = "0.9.0" @@ -1525,6 +1531,7 @@ dependencies = [ "anyhow", "bytecodec", "bytes", + "difference", "env_logger", "futures", "hex", diff --git a/rust/relay/Cargo.toml b/rust/relay/Cargo.toml index 9f1e5c1e9..a10f24641 100644 --- a/rust/relay/Cargo.toml +++ b/rust/relay/Cargo.toml @@ -20,3 +20,4 @@ bytes = "1.4.0" [dev-dependencies] webrtc = "0.7.2" redis = { version = "0.23.0", default-features = false, features = ["tokio-comp"] } +difference = "2.0.0" diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 3aa91683e..a7f84124f 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -190,7 +190,7 @@ impl Eventloop { // Priority 6: Accept new allocations / answer STUN requests etc if let Poll::Ready((buffer, sender)) = self.ip4_socket.poll_recv(cx)? { self.server - .handle_client_input(buffer.filled(), sender, Instant::now())?; + .handle_client_input(buffer.filled(), sender, Instant::now()); continue; // Handle potentially new commands. } diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index 54307f68a..774c5ed85 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -1,10 +1,15 @@ mod channel_data; +mod client_message; use crate::rfc8656::PeerAddressFamilyMismatch; +use crate::server::channel_data::ChannelData; +use crate::server::client_message::{ + Allocate, Binding, ChannelBind, ClientMessage, CreatePermission, Refresh, +}; use crate::stun_codec_ext::{MessageClassExt, MethodExt}; use crate::TimeEvents; use anyhow::Result; -use bytecodec::{DecodeExt, EncodeExt}; +use bytecodec::EncodeExt; use core::fmt; use rand::rngs::mock::StepRng; use rand::rngs::ThreadRng; @@ -23,7 +28,7 @@ use stun_codec::rfc5766::attributes::{ }; use stun_codec::rfc5766::errors::{AllocationMismatch, InsufficientCapacity}; use stun_codec::rfc5766::methods::{ALLOCATE, CHANNEL_BIND, CREATE_PERMISSION, REFRESH}; -use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder, Method, TransactionId}; +use stun_codec::{Message, MessageClass, MessageEncoder, Method, TransactionId}; /// A sans-IO STUN & TURN server. /// @@ -33,7 +38,7 @@ use stun_codec::{Message, MessageClass, MessageDecoder, MessageEncoder, Method, /// /// Additionally, we assume to have complete ownership over the port range `LOWEST_PORT` - `HIGHEST_PORT`. pub struct Server { - decoder: MessageDecoder, + decoder: client_message::Decoder, encoder: MessageEncoder, public_ip4_address: Ipv4Addr, @@ -108,14 +113,6 @@ const HIGHEST_PORT: u16 = 65535; /// The maximum number of ports available for allocation. const MAX_AVAILABLE_PORTS: u16 = HIGHEST_PORT - LOWEST_PORT; -/// The maximum lifetime of an allocation. -const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600); - -/// The default lifetime of an allocation. -/// -/// See . -const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600); - /// The duration of a channel binding. /// /// See . @@ -149,71 +146,58 @@ where /// Process the bytes received from a client. /// /// After calling this method, you should call [`Server::next_command`] until it returns `None`. - pub fn handle_client_input( - &mut self, - bytes: &[u8], - sender: SocketAddr, - now: Instant, - ) -> Result<()> { + pub fn handle_client_input(&mut self, bytes: &[u8], sender: SocketAddr, now: Instant) { if tracing::enabled!(target: "wire", tracing::Level::TRACE) { let hex_bytes = hex::encode(bytes); tracing::trace!(target: "wire", r#"Input::client("{sender}","{hex_bytes}")"#); } - // De-multiplex as per . - match bytes.first() { - Some(0..=3) => { - let Ok(message) = self.decoder.decode_from_bytes(bytes)? else { - tracing::warn!(target: "relay", "received broken STUN message from {sender}"); - return Ok(()); - }; - - tracing::trace!(target: "relay", "Received {} {} from {sender}", message.method().as_str(), message.class().as_str()); - - self.dispatch_stun_message(message, sender, now, |server, message, sender, now| { - use MessageClass::*; - match (message.method(), message.class()) { - (BINDING, Request) => { - server.handle_binding_request(message, sender); - Ok(()) - } - (ALLOCATE, Request) => server.handle_allocate_request(message, sender, now), - (REFRESH, Request) => server.handle_refresh_request(message, sender, now), - (CHANNEL_BIND, Request) => { - server.handle_channel_bind_request(message, sender, now) - } - (CREATE_PERMISSION, Request) => { - server.handle_create_permission_request(message, sender, now) - } - (_, Indication) => { - tracing::trace!(target: "relay", "Indications are not yet implemented"); - - Err(ErrorCode::from(BadRequest)) - } - _ => Err(ErrorCode::from(BadRequest)), - } - }); + let result = match self.decoder.decode(bytes) { + Ok(Ok(ClientMessage::Allocate(request))) => { + self.handle_allocate_request(request, sender, now) + } + Ok(Ok(ClientMessage::Refresh(request))) => { + self.handle_refresh_request(request, sender, now) + } + Ok(Ok(ClientMessage::ChannelBind(request))) => { + self.handle_channel_bind_request(request, sender, now) + } + Ok(Ok(ClientMessage::CreatePermission(request))) => { + self.handle_create_permission_request(request, sender, now) + } + Ok(Ok(ClientMessage::Binding(request))) => { + self.handle_binding_request(request, sender); + return; + } + Ok(Ok(ClientMessage::ChannelData(msg))) => { + self.handle_channel_data_message(msg, sender, now); + return; } - Some(64..=79) => { - let (channel, data) = match channel_data::parse(bytes) { - Ok(v) => v, - Err(e) => { - tracing::debug!( - target: "relay", - "failed to parse channel data message: {e:#}" - ); - return Ok(()); - } - }; - self.handle_channel_data_message(channel, data, sender, now); - } - _ => { - tracing::trace!(target: "relay", "Received unknown message from {sender}"); - } + // Could parse the bytes but message was semantically invalid (like missing attribute). + Ok(Err(error_code)) => Err(error_code), + + // Parsing the bytes failed. + Err(client_message::Error::BadChannelData(_)) => return, + Err(client_message::Error::DecodeStun(_)) => return, + Err(client_message::Error::UnknownMessageType(_)) => return, + Err(client_message::Error::Eof) => return, + }; + + let Err(mut error_response) = result else { + return; + }; + + // In case of a 401 response, attach a realm and nonce. + if error_response + .get_attribute::() + .map_or(false, |error| error == &ErrorCode::from(Unauthorized)) + { + error_response.add_attribute(Nonce::new("foobar".to_owned()).unwrap().into()); + error_response.add_attribute(Realm::new("firezone".to_owned()).unwrap().into()); } - Ok(()) + self.send_message(error_response, sender); } /// Process the bytes received from an allocation. @@ -264,7 +248,7 @@ where ); let recipient = *client; - let data = channel_data::make(*channel_number, bytes); + let data = ChannelData::new(*channel_number, bytes).to_bytes(); if tracing::enabled!(target: "wire", tracing::Level::TRACE) { let hex_bytes = hex::encode(&data); @@ -321,27 +305,7 @@ where self.pending_commands.pop_front() } - fn dispatch_stun_message( - &mut self, - message: Message, - sender: SocketAddr, - now: Instant, - handler: impl Fn(&mut Self, Message, SocketAddr, Instant) -> Result<(), ErrorCode>, - ) { - let transaction_id = message.transaction_id(); - let method = message.method(); - - if let Err(e) = handler(self, message, sender, now) { - if e.code() == Unauthorized::CODEPOINT { - self.send_message(unauthorized(transaction_id, method), sender); - return; - } - - self.send_message(error_response(transaction_id, method, e), sender); - } - } - - fn handle_binding_request(&mut self, message: Message, sender: SocketAddr) { + pub fn handle_binding_request(&mut self, message: Binding, sender: SocketAddr) { let mut message = Message::new( MessageClass::SuccessResponse, BINDING, @@ -355,35 +319,27 @@ where /// Handle a TURN allocate request. /// /// See for details. - fn handle_allocate_request( + pub fn handle_allocate_request( &mut self, - message: Message, + request: Allocate, sender: SocketAddr, now: Instant, - ) -> Result<(), ErrorCode> { - let _ = message - .get_attribute::() - .ok_or(Unauthorized)?; + ) -> Result<(), Message> { + // TODO: Check validity of message integrity here? if self.allocations.contains_key(&sender) { - return Err(AllocationMismatch.into()); + return Err(error_response(AllocationMismatch, &request)); } if self.allocations_by_port.len() == MAX_AVAILABLE_PORTS as usize { - return Err(InsufficientCapacity.into()); + return Err(error_response(InsufficientCapacity, &request)); } - let requested_transport = message - .get_attribute::() - .ok_or(BadRequest)?; - - if requested_transport.protocol() != UDP_TRANSPORT { - return Err(BadRequest.into()); + if request.requested_transport().protocol() != UDP_TRANSPORT { + return Err(error_response(BadRequest, &request)); } - let requested_lifetime = message.get_attribute::(); - - let effective_lifetime = compute_effective_lifetime(requested_lifetime); + let effective_lifetime = request.effective_lifetime(); // TODO: Do we need to handle DONT-FRAGMENT? // TODO: Do we need to handle EVEN/ODD-PORT? @@ -393,7 +349,7 @@ where let mut message = Message::new( MessageClass::SuccessResponse, ALLOCATE, - message.transaction_id(), + request.transaction_id(), ); let ip4_relay_address = self.public_relay_address_for_port(allocation.port); @@ -430,24 +386,21 @@ where /// Handle a TURN refresh request. /// /// See for details. - fn handle_refresh_request( + pub fn handle_refresh_request( &mut self, - message: Message, + request: Refresh, sender: SocketAddr, now: Instant, - ) -> Result<(), ErrorCode> { - let _ = message - .get_attribute::() - .ok_or(Unauthorized)?; + ) -> Result<(), Message> { + // TODO: Check validity of message integrity here? // TODO: Verify that this is the correct error code. let allocation = self .allocations .get_mut(&sender) - .ok_or(ErrorCode::from(AllocationMismatch))?; + .ok_or(error_response(AllocationMismatch, &request))?; - let requested_lifetime = message.get_attribute::(); - let effective_lifetime = compute_effective_lifetime(requested_lifetime); + let effective_lifetime = request.effective_lifetime(); if effective_lifetime.lifetime().is_zero() { let port = allocation.port; @@ -457,7 +410,7 @@ where self.allocations.remove(&sender); self.allocations_by_port.remove(&port); self.send_message( - refresh_success_response(effective_lifetime, message.transaction_id()), + refresh_success_response(effective_lifetime, request.transaction_id()), sender, ); @@ -486,7 +439,7 @@ where deadline: wake_deadline, }); self.send_message( - refresh_success_response(effective_lifetime, message.transaction_id()), + refresh_success_response(effective_lifetime, request.transaction_id()), sender, ); @@ -496,49 +449,40 @@ where /// Handle a TURN channel bind request. /// /// See for details. - fn handle_channel_bind_request( + pub fn handle_channel_bind_request( &mut self, - message: Message, + request: ChannelBind, sender: SocketAddr, now: Instant, - ) -> Result<(), ErrorCode> { - let _ = message - .get_attribute::() - .ok_or(Unauthorized)?; + ) -> Result<(), Message> { + // TODO: Check validity of message integrity here? let allocation = self .allocations .get_mut(&sender) - .ok_or(ErrorCode::from(AllocationMismatch))?; + .ok_or(error_response(AllocationMismatch, &request))?; - let requested_channel = message - .get_attribute::() - .ok_or(ErrorCode::from(BadRequest))? - .value(); - - let peer_address = message - .get_attribute::() - .ok_or(ErrorCode::from(BadRequest))? - .address(); + let requested_channel = request.channel_number().value(); + let peer_address = request.xor_peer_address().address(); // Note: `channel_number` is enforced to be in the correct range. // Check that our allocation can handle the requested peer addr. if !allocation.can_relay_to(peer_address) { - return Err(ErrorCode::from(PeerAddressFamilyMismatch)); + return Err(error_response(PeerAddressFamilyMismatch, &request)); } // Ensure the same address isn't already bound to a different channel. if let Some(number) = self.channel_numbers_by_peer.get(&peer_address) { if number != &requested_channel { - return Err(ErrorCode::from(BadRequest)); + return Err(error_response(BadRequest, &request)); } } // Ensure the channel is not already bound to a different address. if let Some(channel) = self.channels_by_number.get_mut(&requested_channel) { if channel.peer_address != peer_address { - return Err(ErrorCode::from(BadRequest)); + return Err(error_response(BadRequest, &request)); } // Binding requests for existing channels act as a refresh for the binding. @@ -552,7 +496,7 @@ where TimedAction::UnbindChannel(requested_channel), ); self.send_message( - channel_bind_success_response(message.transaction_id()), + channel_bind_success_response(request.transaction_id()), sender, ); @@ -567,7 +511,7 @@ where let allocation_id = allocation.id; self.create_channel_binding(requested_channel, peer_address, allocation_id, now); self.send_message( - channel_bind_success_response(message.transaction_id()), + channel_bind_success_response(request.transaction_id()), sender, ); @@ -582,15 +526,14 @@ where /// /// This TURN server implementation does not support relaying data other than through channels. /// Thus, creating a permission is a no-op that always succeeds. - fn handle_create_permission_request( + pub fn handle_create_permission_request( &mut self, - message: Message, + message: CreatePermission, sender: SocketAddr, _: Instant, - ) -> Result<(), ErrorCode> { - let _ = message - .get_attribute::() - .ok_or(Unauthorized)?; + ) -> Result<(), Message> { + // TODO: Check validity of message integrity here? + self.send_message( create_permission_success_response(message.transaction_id()), sender, @@ -599,13 +542,15 @@ where Ok(()) } - fn handle_channel_data_message( + pub fn handle_channel_data_message( &mut self, - channel_number: u16, - data: &[u8], + message: ChannelData, sender: SocketAddr, _: Instant, ) { + let channel_number = message.channel(); + let data = message.data(); + let Some(channel) = self.channels_by_number.get(&channel_number) else { tracing::debug!(target: "relay", "Channel {channel_number} does not exist, refusing to forward data"); return; @@ -834,37 +779,45 @@ enum TimedAction { DeleteChannel(u16), } -/// Computes the effective lifetime of an allocation. -fn compute_effective_lifetime(requested_lifetime: Option<&Lifetime>) -> Lifetime { - let Some(requested) = requested_lifetime else { - return Lifetime::new(DEFAULT_ALLOCATION_LIFETIME).unwrap(); - }; - - let effective_lifetime = requested.lifetime().min(MAX_ALLOCATION_LIFETIME); - - Lifetime::new(effective_lifetime).unwrap() -} - fn error_response( - transaction_id: TransactionId, - method: Method, - error_code: ErrorCode, + error_code: impl Into, + request: &impl StunRequest, ) -> Message { - let mut message = Message::new(MessageClass::ErrorResponse, method, transaction_id); - message.add_attribute(error_code.into()); + let mut message = Message::new( + MessageClass::ErrorResponse, + request.method(), + request.transaction_id(), + ); + message.add_attribute(Attribute::from(error_code.into())); message } -fn unauthorized(transaction_id: TransactionId, method: Method) -> Message { - let mut message = Message::new(MessageClass::ErrorResponse, method, transaction_id); - message.add_attribute(ErrorCode::from(Unauthorized).into()); - message.add_attribute(Nonce::new("foobar".to_owned()).unwrap().into()); - message.add_attribute(Realm::new("firezone".to_owned()).unwrap().into()); - - message +/// Private helper trait to make [`error_response`] more ergonomic to use. +trait StunRequest { + fn transaction_id(&self) -> TransactionId; + fn method(&self) -> Method; } +macro_rules! impl_stun_request_for { + ($t:ty, $m:expr) => { + impl StunRequest for $t { + fn transaction_id(&self) -> TransactionId { + self.transaction_id() + } + + fn method(&self) -> Method { + $m + } + } + }; +} + +impl_stun_request_for!(Allocate, ALLOCATE); +impl_stun_request_for!(ChannelBind, CHANNEL_BIND); +impl_stun_request_for!(CreatePermission, CREATE_PERMISSION); +impl_stun_request_for!(Refresh, REFRESH); + // Define an enum of all attributes that we care about for our server. stun_codec::define_attribute_enums!( Attribute, @@ -884,17 +837,3 @@ stun_codec::define_attribute_enums!( Username ] ); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn requested_lifetime_is_capped_at_max_lifetime() { - let requested_lifetime = Lifetime::new(Duration::from_secs(10_000_000)).unwrap(); - - let effective_lifetime = compute_effective_lifetime(Some(&requested_lifetime)); - - assert_eq!(effective_lifetime.lifetime(), MAX_ALLOCATION_LIFETIME) - } -} diff --git a/rust/relay/src/server/channel_data.rs b/rust/relay/src/server/channel_data.rs index c43efa41f..d3f410616 100644 --- a/rust/relay/src/server/channel_data.rs +++ b/rust/relay/src/server/channel_data.rs @@ -1,27 +1,61 @@ -use anyhow::{bail, Result}; use bytes::{BufMut, BytesMut}; +use std::io; -pub(crate) fn make(channel: u16, data: &[u8]) -> Vec { - let mut message = BytesMut::with_capacity(2 + 2 + data.len()); - - message.put_u16(channel); - message.put_u16(data.len() as u16); - message.put_slice(data); - - message.freeze().to_vec() +pub struct ChannelData<'a> { + channel: u16, + data: &'a [u8], } -pub(crate) fn parse(data: &[u8]) -> Result<(u16, &[u8])> { - if data.len() < 4 { - bail!("must have at least 4 bytes for channel data message") +impl<'a> ChannelData<'a> { + pub fn parse(data: &'a [u8]) -> Result { + if data.len() < 4 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "channel data messages are at least 4 bytes long", + )); + } + + let channel_number = u16::from_be_bytes([data[0], data[1]]); + let length = u16::from_be_bytes([data[2], data[3]]); + + let actual_payload_length = data.len() - 4; + + if actual_payload_length != length as usize { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "channel data message specified {length} bytes but got {actual_payload_length}" + ), + )); + } + + Ok(ChannelData { + channel: channel_number, + data: &data[4..], + }) } - let channel_number = u16::from_be_bytes([data[0], data[1]]); - let length = u16::from_be_bytes([data[2], data[3]]); + pub fn new(channel: u16, data: &'a [u8]) -> Self { + ChannelData { channel, data } + } - anyhow::ensure!((data.len() - 4) == length as usize); + pub fn to_bytes(&self) -> Vec { + let mut message = BytesMut::with_capacity(2 + 2 + self.data.len()); - Ok((channel_number, &data[4..])) + message.put_u16(self.channel); + message.put_u16(self.data.len() as u16); + message.put_slice(self.data); + + message.freeze().into() + } + + pub fn channel(&self) -> u16 { + self.channel + } + + pub fn data(&self) -> &[u8] { + self.data + } } // TODO: tests diff --git a/rust/relay/src/server/client_message.rs b/rust/relay/src/server/client_message.rs new file mode 100644 index 000000000..243f1784e --- /dev/null +++ b/rust/relay/src/server/client_message.rs @@ -0,0 +1,400 @@ +use crate::server::channel_data::ChannelData; +use crate::Attribute; +use bytecodec::DecodeExt; +use std::io; +use std::time::Duration; +use stun_codec::rfc5389::attributes::{ErrorCode, MessageIntegrity, Username}; +use stun_codec::rfc5389::errors::{BadRequest, Unauthorized}; +use stun_codec::rfc5389::methods::BINDING; +use stun_codec::rfc5766::attributes::{ + ChannelNumber, Lifetime, RequestedTransport, XorPeerAddress, +}; +use stun_codec::rfc5766::methods::{ALLOCATE, CHANNEL_BIND, CREATE_PERMISSION, REFRESH}; +use stun_codec::{BrokenMessage, Message, MessageClass, TransactionId}; + +/// The maximum lifetime of an allocation. +const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600); + +/// The default lifetime of an allocation. +/// +/// See . +const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600); + +#[derive(Default)] +pub struct Decoder { + stun_message_decoder: stun_codec::MessageDecoder, +} + +impl Decoder { + pub fn decode<'a>( + &mut self, + input: &'a [u8], + ) -> Result, Message>, Error> { + // De-multiplex as per . + match input.first() { + Some(0..=3) => { + let message = self.stun_message_decoder.decode_from_bytes(input)??; + + use MessageClass::*; + match (message.method(), message.class()) { + (BINDING, Request) => Ok(Ok(ClientMessage::Binding(Binding::parse(&message)))), + (ALLOCATE, Request) => { + Ok(Allocate::parse(&message).map(ClientMessage::Allocate)) + } + (REFRESH, Request) => Ok(Refresh::parse(&message).map(ClientMessage::Refresh)), + (CHANNEL_BIND, Request) => { + Ok(ChannelBind::parse(&message).map(ClientMessage::ChannelBind)) + } + (CREATE_PERMISSION, Request) => { + Ok(CreatePermission::parse(&message).map(ClientMessage::CreatePermission)) + } + (_, Request) => Ok(Err(bad_request(&message))), + (method, class) => { + Err(Error::DecodeStun(bytecodec::Error::from(io::Error::new( + io::ErrorKind::Unsupported, + format!( + "handling method {} and {class:?} is not implemented", + method.as_u16() + ), + )))) + } + } + } + Some(64..=79) => Ok(Ok(ClientMessage::ChannelData(ChannelData::parse(input)?))), + Some(other) => Err(Error::UnknownMessageType(*other)), + None => Err(Error::Eof), + } + } +} + +pub enum ClientMessage<'a> { + ChannelData(ChannelData<'a>), + Binding(Binding), + Allocate(Allocate), + Refresh(Refresh), + ChannelBind(ChannelBind), + CreatePermission(CreatePermission), +} + +pub struct Binding { + transaction_id: TransactionId, +} + +impl Binding { + pub fn new(transaction_id: TransactionId) -> Self { + Self { transaction_id } + } + + pub fn parse(message: &Message) -> Self { + let transaction_id = message.transaction_id(); + + Binding { transaction_id } + } + + pub fn transaction_id(&self) -> TransactionId { + self.transaction_id + } +} + +pub struct Allocate { + transaction_id: TransactionId, + message_integrity: MessageIntegrity, + requested_transport: RequestedTransport, + lifetime: Option, +} + +impl Allocate { + pub fn new( + transaction_id: TransactionId, + message_integrity: MessageIntegrity, + requested_transport: RequestedTransport, + lifetime: Option, + ) -> Self { + Self { + transaction_id, + message_integrity, + requested_transport, + lifetime, + } + } + + pub fn parse(message: &Message) -> Result> { + let transaction_id = message.transaction_id(); + let message_integrity = message + .get_attribute::() + .ok_or(unauthorized(message))? + .clone(); + let requested_transport = message + .get_attribute::() + .ok_or(bad_request(message))? + .clone(); + let lifetime = message.get_attribute::().cloned(); + + Ok(Allocate { + transaction_id, + message_integrity, + requested_transport, + lifetime, + }) + } + + pub fn transaction_id(&self) -> TransactionId { + self.transaction_id + } + + pub fn message_integrity(&self) -> &MessageIntegrity { + &self.message_integrity + } + + pub fn requested_transport(&self) -> &RequestedTransport { + &self.requested_transport + } + + pub fn effective_lifetime(&self) -> Lifetime { + compute_effective_lifetime(self.lifetime.as_ref()) + } +} + +pub struct Refresh { + transaction_id: TransactionId, + message_integrity: MessageIntegrity, + lifetime: Option, +} + +impl Refresh { + pub fn new( + transaction_id: TransactionId, + message_integrity: MessageIntegrity, + lifetime: Option, + ) -> Self { + Self { + transaction_id, + message_integrity, + lifetime, + } + } + + pub fn parse(message: &Message) -> Result> { + let transaction_id = message.transaction_id(); + let message_integrity = message + .get_attribute::() + .ok_or(unauthorized(message))? + .clone(); + let lifetime = message.get_attribute::().cloned(); + + Ok(Refresh { + transaction_id, + message_integrity, + lifetime, + }) + } + + pub fn transaction_id(&self) -> TransactionId { + self.transaction_id + } + + pub fn message_integrity(&self) -> &MessageIntegrity { + &self.message_integrity + } + + pub fn effective_lifetime(&self) -> Lifetime { + compute_effective_lifetime(self.lifetime.as_ref()) + } +} + +pub struct ChannelBind { + transaction_id: TransactionId, + channel_number: ChannelNumber, + message_integrity: MessageIntegrity, + xor_peer_address: XorPeerAddress, + username: Username, +} + +impl ChannelBind { + pub fn new( + transaction_id: TransactionId, + channel_number: ChannelNumber, + message_integrity: MessageIntegrity, + username: Username, + xor_peer_address: XorPeerAddress, + ) -> Self { + Self { + transaction_id, + channel_number, + message_integrity, + xor_peer_address, + username, + } + } + + pub fn parse(message: &Message) -> Result> { + let transaction_id = message.transaction_id(); + let channel_number = message + .get_attribute::() + .copied() + .ok_or(bad_request(message))?; + let message_integrity = message + .get_attribute::() + .ok_or(unauthorized(message))? + .clone(); + let username = message + .get_attribute::() + .ok_or(bad_request(message))? + .clone(); + let xor_peer_address = message + .get_attribute::() + .ok_or(bad_request(message))? + .clone(); + + Ok(ChannelBind { + transaction_id, + channel_number, + message_integrity, + xor_peer_address, + username, + }) + } + + pub fn transaction_id(&self) -> TransactionId { + self.transaction_id + } + + pub fn channel_number(&self) -> ChannelNumber { + self.channel_number + } + + pub fn message_integrity(&self) -> &MessageIntegrity { + &self.message_integrity + } + + pub fn xor_peer_address(&self) -> &XorPeerAddress { + &self.xor_peer_address + } + + pub fn username(&self) -> &Username { + &self.username + } +} + +pub struct CreatePermission { + transaction_id: TransactionId, + message_integrity: MessageIntegrity, + username: Username, +} + +impl CreatePermission { + pub fn new( + transaction_id: TransactionId, + message_integrity: MessageIntegrity, + username: Username, + ) -> Self { + Self { + transaction_id, + message_integrity, + username, + } + } + + pub fn parse(message: &Message) -> Result> { + let transaction_id = message.transaction_id(); + let message_integrity = message + .get_attribute::() + .ok_or(unauthorized(message))? + .clone(); + let username = message + .get_attribute::() + .ok_or(bad_request(message))? + .clone(); + + Ok(CreatePermission { + transaction_id, + message_integrity, + username, + }) + } + + pub fn transaction_id(&self) -> TransactionId { + self.transaction_id + } + + pub fn message_integrity(&self) -> &MessageIntegrity { + &self.message_integrity + } + + pub fn username(&self) -> &Username { + &self.username + } +} + +/// Computes the effective lifetime of an allocation. +fn compute_effective_lifetime(requested_lifetime: Option<&Lifetime>) -> Lifetime { + let Some(requested) = requested_lifetime else { + return Lifetime::new(DEFAULT_ALLOCATION_LIFETIME).unwrap(); + }; + + let effective_lifetime = requested.lifetime().min(MAX_ALLOCATION_LIFETIME); + + Lifetime::new(effective_lifetime).unwrap() +} + +fn bad_request(message: &Message) -> Message { + let mut message = Message::new( + MessageClass::ErrorResponse, + message.method(), + message.transaction_id(), + ); + message.add_attribute(ErrorCode::from(BadRequest).into()); + + message +} + +fn unauthorized(message: &Message) -> Message { + let mut message = Message::new( + MessageClass::ErrorResponse, + message.method(), + message.transaction_id(), + ); + message.add_attribute(ErrorCode::from(Unauthorized).into()); + + message +} + +#[derive(Debug)] +pub enum Error { + BadChannelData(io::Error), + DecodeStun(bytecodec::Error), + UnknownMessageType(u8), + Eof, +} + +impl From for Error { + fn from(msg: BrokenMessage) -> Self { + Error::DecodeStun(msg.into()) + } +} + +impl From for Error { + fn from(error: bytecodec::Error) -> Self { + Error::DecodeStun(error) + } +} + +impl From for Error { + fn from(error: io::Error) -> Self { + Error::BadChannelData(error) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn requested_lifetime_is_capped_at_max_lifetime() { + let requested_lifetime = Lifetime::new(Duration::from_secs(10_000_000)).unwrap(); + + let effective_lifetime = compute_effective_lifetime(Some(&requested_lifetime)); + + assert_eq!(effective_lifetime.lifetime(), MAX_ALLOCATION_LIFETIME) + } +} diff --git a/rust/relay/tests/regression.rs b/rust/relay/tests/regression.rs index fadfbea72..5666ae7a4 100644 --- a/rust/relay/tests/regression.rs +++ b/rust/relay/tests/regression.rs @@ -1,4 +1,4 @@ -use bytecodec::EncodeExt; +use bytecodec::{DecodeExt, EncodeExt}; use hex_literal::hex; use relay::{AllocationId, Attribute, Command, Server}; use std::collections::HashMap; @@ -6,7 +6,9 @@ use std::time::{Duration, Instant}; use stun_codec::rfc5389::attributes::{MessageIntegrity, Realm, Username}; use stun_codec::rfc5766::attributes::Lifetime; use stun_codec::rfc5766::methods::REFRESH; -use stun_codec::{Message, MessageClass, MessageEncoder, TransactionId}; +use stun_codec::{ + DecodedMessage, Message, MessageClass, MessageDecoder, MessageEncoder, TransactionId, +}; #[test] fn stun_binding_request() { @@ -123,11 +125,6 @@ fn ping_pong_relay() { Output::send_message("127.0.0.1:42677","010800002112a442dc5c115f6b727e25a54b55d3") ]), ( - Input::client("127.0.0.1:42677","001600242112a4421b6fb3ce9cefcd57ef3a9edb00130009484f4c4550554e4348000000001200080001f2405e12a443802800041bfe0967", now), - &[ - Output::send_message("127.0.0.1:42677","011600142112a4421b6fb3ce9cefcd57ef3a9edb0009000f00000400426164205265717565737400") - ]), - ( Input::client("127.0.0.1:42677","000900542112a4420afbde5aaacfc1e9316beae9001200080001f2405e12a443000c000440000000000600047465737400140008666972657a6f6e6500150006666f6f626172000000080014aca01c6cdc1fc5339a309e5bccac3df5c903e33e802800041fe4b79b", now), &[ Output::send_message("127.0.0.1:42677","010900002112a4420afbde5aaacfc1e9316beae9") @@ -159,7 +156,7 @@ fn run_regression_test(sequence: &[(Input, &[Output])]) { let input = hex::decode(data).unwrap(); let from = from.parse().unwrap(); - server.handle_client_input(&input, from, *now).unwrap(); + server.handle_client_input(&input, from, *now); } Input::Time(now) => { server.handle_deadline_reached(*now); @@ -173,13 +170,30 @@ fn run_regression_test(sequence: &[(Input, &[Output])]) { } for expected_output in *output { - let actual_output = server - .next_command() - .unwrap_or_else(|| panic!("no commands produced but expected {expected_output:?}")); + let Some(actual_output) = server.next_command() else { + let msg = match expected_output { + Output::SendMessage((recipient, bytes)) => format!("to send message {:?} to {recipient}", parse_hex_message(bytes)), + Output::Forward((ip, data, port)) => format!("forward '{data}' to {ip} on port {port}"), + Output::Wake(instant) => format!("to be woken at {instant:?}"), + Output::CreateAllocation(port) => format!("to create allocation on port {port}"), + Output::ExpireAllocation(port) => format!("to free allocation on port {port}"), + }; + + panic!("No commands produced but expected {msg}"); + }; match (expected_output, actual_output) { (Output::SendMessage((to, bytes)), Command::SendMessage { payload, recipient }) => { - assert_eq!(*bytes, hex::encode(payload)); + let expected_bytes = hex::decode(bytes).unwrap(); + + if expected_bytes != payload { + let expected_message = + format!("{:?}", parse_message(expected_bytes.as_ref())); + let actual_message = format!("{:?}", parse_message(payload.as_ref())); + + difference::assert_diff!(&expected_message, &actual_message, "\n", 0); + } + assert_eq!(recipient, to.parse().unwrap()); } ( @@ -240,6 +254,15 @@ where hex::encode(MessageEncoder::new().encode_into_bytes(message).unwrap()) } +fn parse_hex_message(message: &str) -> DecodedMessage { + let message = hex::decode(message).unwrap(); + MessageDecoder::new().decode_from_bytes(&message).unwrap() +} + +fn parse_message(message: &[u8]) -> DecodedMessage { + MessageDecoder::new().decode_from_bytes(message).unwrap() +} + enum Input { Client(Ip, Bytes, Instant), Peer(Ip, Bytes, u16),