From 220c9ee1e100737314ea038aa0d475ac911d6f05 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 26 Feb 2024 14:40:09 +1100 Subject: [PATCH] fix(connlib): correctly handle GRO (#3732) With the use of `quinn-udp`, we are actually already using GRO for reading packets from the UDP socket. Especially during a test like iperf, it is thus very likely to read multiple packets from the same peer in a single syscall. In that case, `stride` tells us how they are split. Without handling `stride` correctly, we would be feeding multiple packets at once to boringtun which would (obviously) choke on it because its checksum verification fails. It turns out we can actually handle this quite nicely by returning an `Iterator` and decapsulating them one-by-one. --- .github/workflows/ci.yml | 6 +++ rust/connlib/tunnel/src/lib.rs | 84 +++++++++++++++--------------- rust/connlib/tunnel/src/sockets.rs | 19 ++++--- 3 files changed, 60 insertions(+), 49 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 95568cefe..c516f48c8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -288,12 +288,18 @@ jobs: - name: Show Client logs if: "!cancelled()" run: docker compose logs client + - name: Show Client UDP stats + if: "!cancelled()" + run: docker compose exec client cat /proc/net/udp - name: Show Relay logs if: "!cancelled()" run: docker compose logs relay - name: Show Gateway logs if: "!cancelled()" run: docker compose logs gateway + - name: Show Gateway UDP stats + if: "!cancelled()" + run: docker compose exec gateway cat /proc/net/udp - name: Show API logs if: "!cancelled()" run: docker compose logs api diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 1e2b7c730..20439c795 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -19,6 +19,7 @@ use std::{ collections::{HashMap, HashSet}, fmt, hash::Hash, + io, net::IpAddr, sync::Arc, task::{ready, Context, Poll}, @@ -93,9 +94,8 @@ where _ => (), } - match self.connections_state.poll_sockets(cx) { - Poll::Ready(packet) => { - device.write(packet)?; + match self.connections_state.poll_sockets(device, cx)? { + Poll::Ready(()) => { cx.waker().wake_by_ref(); } Poll::Pending => {} @@ -158,9 +158,8 @@ where _ => (), } - match self.connections_state.poll_sockets(cx) { - Poll::Ready(packet) => { - device.write(packet)?; + match self.connections_state.poll_sockets(device, cx)? { + Poll::Ready(()) => { cx.waker().wake_by_ref(); } Poll::Pending => {} @@ -295,7 +294,7 @@ where Ok(()) } - fn poll_sockets<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_sockets(&mut self, device: &mut Device, cx: &mut Context<'_>) -> Poll> { let received = match ready!(self.sockets.poll_recv_from(cx)) { Ok(received) => received, Err(e) => { @@ -306,54 +305,55 @@ where } }; - let Received { - local, - from, - packet, - } = received; + for received in received { + let Received { + local, + from, + packet, + } = received; - let (conn_id, packet) = match self.node.decapsulate( - local, - from, - packet, - std::time::Instant::now(), - self.write_buf.as_mut(), - ) { - Ok(Some(packet)) => packet, - Ok(None) => { - cx.waker().wake_by_ref(); - return Poll::Pending; - } - Err(e) => { - tracing::warn!(%local, %from, "Failed to decapsulate incoming packet: {e}"); + let (conn_id, packet) = match self.node.decapsulate( + local, + from, + packet.as_ref(), + std::time::Instant::now(), + self.write_buf.as_mut(), + ) { + Ok(Some(packet)) => packet, + Ok(None) => { + continue; + } + Err(e) => { + tracing::warn!(%local, %from, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}"); - cx.waker().wake_by_ref(); - return Poll::Pending; - } - }; + continue; + } + }; - tracing::trace!(target: "wire", %local, %from, bytes = %packet.packet().len(), "read new packet"); + tracing::trace!(target: "wire", %local, %from, bytes = %packet.packet().len(), "read new packet"); - let Some(peer) = self.peers_by_id.get(&conn_id) else { - tracing::error!(%conn_id, %local, %from, "Couldn't find connection"); + let Some(peer) = self.peers_by_id.get(&conn_id) else { + tracing::error!(%conn_id, %local, %from, "Couldn't find connection"); - cx.waker().wake_by_ref(); - return Poll::Pending; - }; + continue; + }; - let packet_len = packet.packet().len(); - let packet = - match peer.untransform(packet.source(), &mut self.write_buf.as_mut()[..packet_len]) { + let packet_len = packet.packet().len(); + let packet = match peer + .untransform(packet.source(), &mut self.write_buf.as_mut()[..packet_len]) + { Ok(packet) => packet, Err(e) => { tracing::warn!(%conn_id, %local, %from, "Failed to transform packet: {e}"); - cx.waker().wake_by_ref(); - return Poll::Pending; + continue; } }; - Poll::Ready(packet) + device.write(packet)?; + } + + Poll::Ready(Ok(())) } fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 890cfcadf..a9b1e3448 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -98,7 +98,7 @@ impl Sockets { pub fn poll_recv_from<'a>( &'a mut self, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>>> { if let Some(Poll::Ready(packet)) = self.socket_v4.as_mut().map(|s| s.poll_recv_from(cx)) { return Poll::Ready(packet); } @@ -150,7 +150,10 @@ impl Socket { } #[allow(clippy::type_complexity)] - fn poll_recv_from<'b>(&'b mut self, cx: &mut Context<'_>) -> Poll>> { + fn poll_recv_from<'b>( + &'b mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { let Socket { port, socket, @@ -180,11 +183,13 @@ impl Socket { let local = SocketAddr::new(local_ip, *port); - return Poll::Ready(Ok(Received { - local, - from: meta.addr, - packet: &mut buffer[..meta.len], - })); + return Poll::Ready(Ok(buffer[..meta.len].chunks(meta.stride).map( + move |packet| Received { + local, + from: meta.addr, + packet, + }, + ))); } } }