diff --git a/rust/relay/ebpf-turn-router/src/channel_data.rs b/rust/relay/ebpf-turn-router/src/channel_data.rs index 2a97affdf..ed0072d43 100644 --- a/rust/relay/ebpf-turn-router/src/channel_data.rs +++ b/rust/relay/ebpf-turn-router/src/channel_data.rs @@ -42,6 +42,10 @@ impl<'a> ChannelData<'a> { pub fn number(&self) -> u16 { u16::from_be_bytes(self.inner.number) } + + pub fn length(&self) -> u16 { + u16::from_be_bytes(self.inner.length) + } } #[repr(C)] diff --git a/rust/relay/ebpf-turn-router/src/main.rs b/rust/relay/ebpf-turn-router/src/main.rs index 9f934477e..e2c22a7e6 100644 --- a/rust/relay/ebpf-turn-router/src/main.rs +++ b/rust/relay/ebpf-turn-router/src/main.rs @@ -154,6 +154,8 @@ fn try_handle_ipv4_channel_data_to_udp(ctx: &XdpContext, ipv4: Ip4, udp: Udp) -> port_and_peer.allocation_port(), port_and_peer.peer_port(), new_udp_len, + cd.number(), + cd.length(), ); remove_channel_data_header_ipv4(ctx)?; @@ -173,15 +175,21 @@ fn try_handle_ipv4_udp_to_channel_data(ctx: &XdpContext, ipv4: Ip4, udp: Udp) -> let udp_len = udp.len(); let new_udp_len = udp_len + CdHdr::LEN as u16; + + let channel_number = client_and_channel.channel(); + let channel_data_length = udp_len - UdpHdr::LEN as u16; + udp.update( pseudo_header, 3478, client_and_channel.client_port(), new_udp_len, + channel_number, + channel_data_length, ); - let cd_num = client_and_channel.channel().to_be_bytes(); - let cd_len = (udp_len - UdpHdr::LEN as u16).to_be_bytes(); // The `length` field in the UDP header includes the header itself. For the channel-data field, we only want the length of the payload. + let cd_num = channel_number.to_be_bytes(); + let cd_len = channel_data_length.to_be_bytes(); // The `length` field in the UDP header includes the header itself. For the channel-data field, we only want the length of the payload. let channel_data_header = [cd_num[0], cd_num[1], cd_len[0], cd_len[1]]; @@ -239,15 +247,21 @@ fn try_handle_ipv6_udp_to_channel_data(ctx: &XdpContext, ipv6: Ip6, udp: Udp) -> let udp_len = udp.len(); let new_udp_len = udp_len + CdHdr::LEN as u16; + + let channel_number = client_and_channel.channel(); + let channel_data_length = udp_len - UdpHdr::LEN as u16; + udp.update( pseudo_header, 3478, client_and_channel.client_port(), new_udp_len, + channel_number, + channel_data_length, ); - let cd_num = client_and_channel.channel().to_be_bytes(); - let cd_len = (udp_len - UdpHdr::LEN as u16).to_be_bytes(); // The `length` field in the UDP header includes the header itself. For the channel-data field, we only want the length of the payload. + let cd_num = channel_number.to_be_bytes(); + let cd_len = channel_data_length.to_be_bytes(); // The `length` field in the UDP header includes the header itself. For the channel-data field, we only want the length of the payload. let channel_data_header = [cd_num[0], cd_num[1], cd_len[0], cd_len[1]]; @@ -274,6 +288,8 @@ fn try_handle_ipv6_channel_data_to_udp(ctx: &XdpContext, ipv6: Ip6, udp: Udp) -> port_and_peer.allocation_port(), port_and_peer.peer_port(), new_udp_len, + cd.number(), + cd.length(), ); remove_channel_data_header_ipv6(ctx)?; diff --git a/rust/relay/ebpf-turn-router/src/udp.rs b/rust/relay/ebpf-turn-router/src/udp.rs index eb201ef09..161bb2d11 100644 --- a/rust/relay/ebpf-turn-router/src/udp.rs +++ b/rust/relay/ebpf-turn-router/src/udp.rs @@ -42,11 +42,23 @@ impl<'a> Udp<'a> { new_src: u16, new_dst: u16, new_len: u16, + channel_number: u16, + channel_data_len: u16, ) { let src = self.src(); let dst = self.dst(); let len = self.len(); + let payload_checksum_update = if new_len > len { + ChecksumUpdate::default() + .add_u16(channel_number) + .add_u16(channel_data_len) + } else { + ChecksumUpdate::default() + .remove_u16(channel_number) + .remove_u16(channel_data_len) + }; + self.inner.source = new_src.to_be_bytes(); self.inner.dest = new_dst.to_be_bytes(); self.inner.len = new_len.to_be_bytes(); @@ -56,6 +68,7 @@ impl<'a> Udp<'a> { if crate::config::udp_checksum_enabled() { self.inner.check = ChecksumUpdate::new(u16::from_be_bytes(self.inner.check)) .add_update(ip_pseudo_header) + .add_update(payload_checksum_update) .remove_u16(len) .add_u16(new_len) .remove_u16(src)