diff --git a/rust/relay/server/src/server.rs b/rust/relay/server/src/server.rs index 5f33cb9d2..1e11a9ffc 100644 --- a/rust/relay/server/src/server.rs +++ b/rust/relay/server/src/server.rs @@ -52,9 +52,6 @@ use uuid::Uuid; /// Additionally, we assume to have complete ownership over the port range `lowest_port` - `highest_port`. #[derive(Debug)] pub struct Server { - decoder: client_message::Decoder, - encoder: auth::MessageEncoder, - public_address: IpStack, /// All client allocations, indexed by client's socket address. @@ -195,8 +192,6 @@ where .init(); Self { - decoder: Default::default(), - encoder: Default::default(), public_address: public_address.into(), allocations: Default::default(), clients_by_allocation: Default::default(), @@ -272,7 +267,7 @@ where ) -> Option<(AllocationPort, PeerSocket)> { tracing::trace!(target: "wire", num_bytes = %bytes.len()); - match self.decoder.decode(bytes) { + match client_message::decode(bytes) { Ok(Ok(message)) => { return self.handle_client_message(message, sender, now); } @@ -988,7 +983,7 @@ where let error_code = message.get_attribute::().map(|e| e.code()); tracing::trace!(target: "relay", method = %message.method(), class = %message.class(), "Sending message"); - let Ok(bytes) = self.encoder.encode_into_bytes(message) else { + let Ok(bytes) = auth::MessageEncoder::default().encode_into_bytes(message) else { debug_assert!(false, "Encoding should never fail"); return; }; diff --git a/rust/relay/server/src/server/client_message.rs b/rust/relay/server/src/server/client_message.rs index e645d4c62..c07d74663 100644 --- a/rust/relay/server/src/server/client_message.rs +++ b/rust/relay/server/src/server/client_message.rs @@ -28,63 +28,51 @@ const MAX_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600); /// See . const DEFAULT_ALLOCATION_LIFETIME: Duration = Duration::from_secs(600); -#[derive(Default, Debug)] -pub struct Decoder { - stun_message_decoder: stun_codec::MessageDecoder, -} +pub fn decode(input: &[u8]) -> Result>, Error> { + let mut decoder = stun_codec::MessageDecoder::default(); -impl Decoder { - pub fn decode<'a>( - &mut self, - input: &'a [u8], - ) -> Result, Message>, Error> { - // De-multiplex as per . - match input.first() { - Some(0..=3) => { - let message = match self.stun_message_decoder.decode_from_bytes(input)? { - Ok(message) => message, - Err(broken_message) => { - let method = broken_message.method(); - let transaction_id = broken_message.transaction_id(); - let error = broken_message.error().clone(); + // De-multiplex as per . + match input.first() { + Some(0..=3) => { + let message = match decoder.decode_from_bytes(input)? { + Ok(message) => message, + Err(broken_message) => { + let method = broken_message.method(); + let transaction_id = broken_message.transaction_id(); + let error = broken_message.error().clone(); - tracing::debug!(transaction_id = ?transaction_id, %method, %error, "Failed to decode attributes of message"); + tracing::debug!(transaction_id = ?transaction_id, %method, %error, "Failed to decode attributes of message"); - let error_code = ErrorCode::from(error); + let error_code = ErrorCode::from(error); - return Ok(Err(error_response(method, transaction_id, error_code))); - } - }; - - use MessageClass::*; - match (message.method(), message.class()) { - (BINDING, Request) => Ok(Ok(ClientMessage::Binding(Binding::parse(&message)))), - (ALLOCATE, Request) => { - Ok(Allocate::parse(&message).map(ClientMessage::Allocate)) - } - (REFRESH, Request) => Ok(Ok(ClientMessage::Refresh(Refresh::parse(&message)))), - (CHANNEL_BIND, Request) => { - Ok(ChannelBind::parse(&message).map(ClientMessage::ChannelBind)) - } - (CREATE_PERMISSION, Request) => Ok(Ok(ClientMessage::CreatePermission( - CreatePermission::parse(&message), - ))), - (_, Request) => Ok(Err(bad_request(&message))), - (method, class) => { - Err(Error::DecodeStun(bytecodec::Error::from(io::Error::new( - io::ErrorKind::Unsupported, - format!( - "handling method {} and {class:?} is not implemented", - method.as_u16() - ), - )))) - } + return Ok(Err(error_response(method, transaction_id, error_code))); } + }; + + use MessageClass::*; + match (message.method(), message.class()) { + (BINDING, Request) => Ok(Ok(ClientMessage::Binding(Binding::parse(&message)))), + (ALLOCATE, Request) => Ok(Allocate::parse(&message).map(ClientMessage::Allocate)), + (REFRESH, Request) => Ok(Ok(ClientMessage::Refresh(Refresh::parse(&message)))), + (CHANNEL_BIND, Request) => { + Ok(ChannelBind::parse(&message).map(ClientMessage::ChannelBind)) + } + (CREATE_PERMISSION, Request) => Ok(Ok(ClientMessage::CreatePermission( + CreatePermission::parse(&message), + ))), + (_, Request) => Ok(Err(bad_request(&message))), + (method, class) => Err(Error::DecodeStun(bytecodec::Error::from(io::Error::new( + io::ErrorKind::Unsupported, + format!( + "handling method {} and {class:?} is not implemented", + method.as_u16() + ), + )))), } - Some(64..=79) => Ok(Ok(ClientMessage::ChannelData(ChannelData::parse(input)?))), - Some(other) => Err(Error::UnknownMessageType(*other)), - None => Err(Error::Eof), } + Some(64..=79) => Ok(Ok(ClientMessage::ChannelData(ChannelData::parse(input)?))), + Some(other) => Err(Error::UnknownMessageType(*other)), + None => Err(Error::Eof), } }