From c52d88f4219b777edb093c339aea1803ae53bab6 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sat, 19 Apr 2025 01:12:46 +1000 Subject: [PATCH] fix(relay): stateless encoding/decoding (#8810) The STUN message encoder & decoder from `stun_codec` are stateful operations. However, they only operate on one datagram at the time. If encoding or decoding fails, their internal state is corrupted and must be discarded. At present, this doesn't happen which leads to further failures down the line because new datagrams coming in cannot be correctly decoded. To fix this, we scope the stateful nature of these encoders and decoders to their respective functions. Resolves: #8808 --- rust/relay/server/src/server.rs | 9 +- .../relay/server/src/server/client_message.rs | 88 ++++++++----------- 2 files changed, 40 insertions(+), 57 deletions(-) diff --git a/rust/relay/server/src/server.rs b/rust/relay/server/src/server.rs index 5f33cb9d2..1e11a9ffc 100644 --- a/rust/relay/server/src/server.rs +++ b/rust/relay/server/src/server.rs @@ -52,9 +52,6 @@ use uuid::Uuid; /// Additionally, we assume to have complete ownership over the port range `lowest_port` - `highest_port`. #[derive(Debug)] pub struct Server { - decoder: client_message::Decoder, - encoder: auth::MessageEncoder, - public_address: IpStack, /// All client allocations, indexed by client's socket address. @@ -195,8 +192,6 @@ where .init(); Self { - decoder: Default::default(), - encoder: Default::default(), public_address: public_address.into(), allocations: Default::default(), clients_by_allocation: Default::default(), @@ -272,7 +267,7 @@ where ) -> Option<(AllocationPort, PeerSocket)> { tracing::trace!(target: "wire", num_bytes = %bytes.len()); - match self.decoder.decode(bytes) { + match client_message::decode(bytes) { Ok(Ok(message)) => { return self.handle_client_message(message, sender, now); } @@ -988,7 +983,7 @@ where let error_code = message.get_attribute::().map(|e| e.code()); tracing::trace!(target: "relay", method = %message.method(), class = %message.class(), "Sending message"); - let Ok(bytes) = self.encoder.encode_into_bytes(message) else { + let Ok(bytes) = auth::MessageEncoder::default().encode_into_bytes(message) else { debug_assert!(false, "Encoding should never fail"); return; }; diff --git a/rust/relay/server/src/server/client_message.rs b/rust/relay/server/src/server/client_message.rs index e645d4c62..c07d74663 100644 --- a/rust/relay/server/src/server/client_message.rs +++ b/rust/relay/server/src/server/client_message.rs @@ -28,63 +28,51 @@ const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600); /// See . const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600); -#[derive(Default, Debug)] -pub struct Decoder { - stun_message_decoder: stun_codec::MessageDecoder, -} +pub fn decode(input: &[u8]) -> Result>, Error> { + let mut decoder = stun_codec::MessageDecoder::default(); -impl Decoder { - pub fn decode<'a>( - &mut self, - input: &'a [u8], - ) -> Result, Message>, Error> { - // De-multiplex as per . - match input.first() { - Some(0..=3) => { - let message = match self.stun_message_decoder.decode_from_bytes(input)? { - Ok(message) => message, - Err(broken_message) => { - let method = broken_message.method(); - let transaction_id = broken_message.transaction_id(); - let error = broken_message.error().clone(); + // De-multiplex as per . + match input.first() { + Some(0..=3) => { + let message = match decoder.decode_from_bytes(input)? { + Ok(message) => message, + Err(broken_message) => { + let method = broken_message.method(); + let transaction_id = broken_message.transaction_id(); + let error = broken_message.error().clone(); - tracing::debug!(transaction_id = ?transaction_id, %method, %error, "Failed to decode attributes of message"); + tracing::debug!(transaction_id = ?transaction_id, %method, %error, "Failed to decode attributes of message"); - let error_code = ErrorCode::from(error); + let error_code = ErrorCode::from(error); - return Ok(Err(error_response(method, transaction_id, error_code))); - } - }; - - use MessageClass::*; - match (message.method(), message.class()) { - (BINDING, Request) => Ok(Ok(ClientMessage::Binding(Binding::parse(&message)))), - (ALLOCATE, Request) => { - Ok(Allocate::parse(&message).map(ClientMessage::Allocate)) - } - (REFRESH, Request) => Ok(Ok(ClientMessage::Refresh(Refresh::parse(&message)))), - (CHANNEL_BIND, Request) => { - Ok(ChannelBind::parse(&message).map(ClientMessage::ChannelBind)) - } - (CREATE_PERMISSION, Request) => Ok(Ok(ClientMessage::CreatePermission( - CreatePermission::parse(&message), - ))), - (_, Request) => Ok(Err(bad_request(&message))), - (method, class) => { - Err(Error::DecodeStun(bytecodec::Error::from(io::Error::new( - io::ErrorKind::Unsupported, - format!( - "handling method {} and {class:?} is not implemented", - method.as_u16() - ), - )))) - } + return Ok(Err(error_response(method, transaction_id, error_code))); } + }; + + use MessageClass::*; + match (message.method(), message.class()) { + (BINDING, Request) => Ok(Ok(ClientMessage::Binding(Binding::parse(&message)))), + (ALLOCATE, Request) => Ok(Allocate::parse(&message).map(ClientMessage::Allocate)), + (REFRESH, Request) => Ok(Ok(ClientMessage::Refresh(Refresh::parse(&message)))), + (CHANNEL_BIND, Request) => { + Ok(ChannelBind::parse(&message).map(ClientMessage::ChannelBind)) + } + (CREATE_PERMISSION, Request) => Ok(Ok(ClientMessage::CreatePermission( + CreatePermission::parse(&message), + ))), + (_, Request) => Ok(Err(bad_request(&message))), + (method, class) => Err(Error::DecodeStun(bytecodec::Error::from(io::Error::new( + io::ErrorKind::Unsupported, + format!( + "handling method {} and {class:?} is not implemented", + method.as_u16() + ), + )))), } - Some(64..=79) => Ok(Ok(ClientMessage::ChannelData(ChannelData::parse(input)?))), - Some(other) => Err(Error::UnknownMessageType(*other)), - None => Err(Error::Eof), } + Some(64..=79) => Ok(Ok(ClientMessage::ChannelData(ChannelData::parse(input)?))), + Some(other) => Err(Error::UnknownMessageType(*other)), + None => Err(Error::Eof), } }