fix(connlib): correctly handle GRO (#3732)

With the use of `quinn-udp`, we are actually already using GRO for
reading packets from the UDP socket. Especially during a test like
iperf, it is thus very likely to read multiple packets from the same
peer in a single syscall. In that case, `stride` tells us how they are
split.

Without handling `stride` correctly, we would be feeding multiple
packets at once to boringtun which would (obviously) choke on it because
its checksum verification fails.

It turns out we can actually handle this quite nicely by returning an
`Iterator<Item = Received>` and decapsulating them one-by-one.
This commit is contained in:
Thomas Eizinger
2024-02-26 14:40:09 +11:00
committed by GitHub
parent 0ded6ad79d
commit 220c9ee1e1
3 changed files with 60 additions and 49 deletions

View File

@@ -288,12 +288,18 @@ jobs:
- name: Show Client logs
if: "!cancelled()"
run: docker compose logs client
- name: Show Client UDP stats
if: "!cancelled()"
run: docker compose exec client cat /proc/net/udp
- name: Show Relay logs
if: "!cancelled()"
run: docker compose logs relay
- name: Show Gateway logs
if: "!cancelled()"
run: docker compose logs gateway
- name: Show Gateway UDP stats
if: "!cancelled()"
run: docker compose exec gateway cat /proc/net/udp
- name: Show API logs
if: "!cancelled()"
run: docker compose logs api

View File

@@ -19,6 +19,7 @@ use std::{
collections::{HashMap, HashSet},
fmt,
hash::Hash,
io,
net::IpAddr,
sync::Arc,
task::{ready, Context, Poll},
@@ -93,9 +94,8 @@ where
_ => (),
}
match self.connections_state.poll_sockets(cx) {
Poll::Ready(packet) => {
device.write(packet)?;
match self.connections_state.poll_sockets(device, cx)? {
Poll::Ready(()) => {
cx.waker().wake_by_ref();
}
Poll::Pending => {}
@@ -158,9 +158,8 @@ where
_ => (),
}
match self.connections_state.poll_sockets(cx) {
Poll::Ready(packet) => {
device.write(packet)?;
match self.connections_state.poll_sockets(device, cx)? {
Poll::Ready(()) => {
cx.waker().wake_by_ref();
}
Poll::Pending => {}
@@ -295,7 +294,7 @@ where
Ok(())
}
fn poll_sockets<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll<device_channel::Packet<'a>> {
fn poll_sockets(&mut self, device: &mut Device, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let received = match ready!(self.sockets.poll_recv_from(cx)) {
Ok(received) => received,
Err(e) => {
@@ -306,54 +305,55 @@ where
}
};
let Received {
local,
from,
packet,
} = received;
for received in received {
let Received {
local,
from,
packet,
} = received;
let (conn_id, packet) = match self.node.decapsulate(
local,
from,
packet,
std::time::Instant::now(),
self.write_buf.as_mut(),
) {
Ok(Some(packet)) => packet,
Ok(None) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
Err(e) => {
tracing::warn!(%local, %from, "Failed to decapsulate incoming packet: {e}");
let (conn_id, packet) = match self.node.decapsulate(
local,
from,
packet.as_ref(),
std::time::Instant::now(),
self.write_buf.as_mut(),
) {
Ok(Some(packet)) => packet,
Ok(None) => {
continue;
}
Err(e) => {
tracing::warn!(%local, %from, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}");
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
continue;
}
};
tracing::trace!(target: "wire", %local, %from, bytes = %packet.packet().len(), "read new packet");
tracing::trace!(target: "wire", %local, %from, bytes = %packet.packet().len(), "read new packet");
let Some(peer) = self.peers_by_id.get(&conn_id) else {
tracing::error!(%conn_id, %local, %from, "Couldn't find connection");
let Some(peer) = self.peers_by_id.get(&conn_id) else {
tracing::error!(%conn_id, %local, %from, "Couldn't find connection");
cx.waker().wake_by_ref();
return Poll::Pending;
};
continue;
};
let packet_len = packet.packet().len();
let packet =
match peer.untransform(packet.source(), &mut self.write_buf.as_mut()[..packet_len]) {
let packet_len = packet.packet().len();
let packet = match peer
.untransform(packet.source(), &mut self.write_buf.as_mut()[..packet_len])
{
Ok(packet) => packet,
Err(e) => {
tracing::warn!(%conn_id, %local, %from, "Failed to transform packet: {e}");
cx.waker().wake_by_ref();
return Poll::Pending;
continue;
}
};
Poll::Ready(packet)
device.write(packet)?;
}
Poll::Ready(Ok(()))
}
fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Event<TId>> {

View File

@@ -98,7 +98,7 @@ impl Sockets {
pub fn poll_recv_from<'a>(
&'a mut self,
cx: &mut Context<'_>,
) -> Poll<io::Result<Received<'a>>> {
) -> Poll<io::Result<impl Iterator<Item = Received<'a>>>> {
if let Some(Poll::Ready(packet)) = self.socket_v4.as_mut().map(|s| s.poll_recv_from(cx)) {
return Poll::Ready(packet);
}
@@ -150,7 +150,10 @@ impl<const N: usize> Socket<N> {
}
#[allow(clippy::type_complexity)]
fn poll_recv_from<'b>(&'b mut self, cx: &mut Context<'_>) -> Poll<io::Result<Received<'b>>> {
fn poll_recv_from<'b>(
&'b mut self,
cx: &mut Context<'_>,
) -> Poll<io::Result<impl Iterator<Item = Received<'b>>>> {
let Socket {
port,
socket,
@@ -180,11 +183,13 @@ impl<const N: usize> Socket<N> {
let local = SocketAddr::new(local_ip, *port);
return Poll::Ready(Ok(Received {
local,
from: meta.addr,
packet: &mut buffer[..meta.len],
}));
return Poll::Ready(Ok(buffer[..meta.len].chunks(meta.stride).map(
move |packet| Received {
local,
from: meta.addr,
packet,
},
)));
}
}
}