diff --git a/rust/connection-tests/src/main.rs b/rust/connection-tests/src/main.rs index 7d20a04cd..aec8fc2b7 100644 --- a/rust/connection-tests/src/main.rs +++ b/rust/connection-tests/src/main.rs @@ -12,9 +12,9 @@ use firezone_connection::{ Answer, ClientConnectionPool, ConnectionPool, Credentials, IpPacket, Offer, ServerConnectionPool, }; -use futures::{future::BoxFuture, FutureExt}; +use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt}; use pnet_packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet}; -use redis::AsyncCommands; +use redis::{aio::MultiplexedConnection, AsyncCommands}; use secrecy::{ExposeSecret as _, Secret}; use tokio::{io::ReadBuf, net::UdpSocket}; use tracing_subscriber::EnvFilter; @@ -67,7 +67,7 @@ async fn main() -> Result<()> { let redis_host = std::env::var("REDIS_HOST").context("Missing REDIS_HOST env var")?; let redis_client = redis::Client::open(format!("redis://{redis_host}:6379"))?; - let mut redis_connection = redis_client.get_async_connection().await?; + let mut redis_connection = redis_client.get_multiplexed_async_connection().await?; let socket = UdpSocket::bind((listen_addr, 0)).await?; let socket_addr = socket.local_addr()?; @@ -119,46 +119,45 @@ async fn main() -> Result<()> { }, ); - let mut eventloop = Eventloop::new(socket, pool); + let rx = spawn_candidate_task(redis_connection.clone(), "listener_candidates"); + + let mut eventloop = Eventloop::new(socket, pool, rx); let ping_body = rand::random::<[u8; 32]>(); let mut start = Instant::now(); loop { - tokio::select! { - event = poll_fn(|cx| eventloop.poll(cx)) => { - match event? { - Event::Incoming { conn, packet } => { - anyhow::ensure!(conn == 1); - anyhow::ensure!(packet == IpPacket::Ipv4(ip4_udp_ping_packet(dst, source, packet.udp_payload()))); // Expect the listener to flip src and dst + match poll_fn(|cx| eventloop.poll(cx)).await? { + Event::Incoming { conn, packet } => { + anyhow::ensure!(conn == 1); + anyhow::ensure!( + packet + == IpPacket::Ipv4(ip4_udp_ping_packet( + dst, + source, + packet.udp_payload() + )) + ); // Expect the listener to flip src and dst - let rtt = start.elapsed(); + let rtt = start.elapsed(); - tracing::info!("RTT is {rtt:?}"); + tracing::info!("RTT is {rtt:?}"); - return Ok(()) - } - Event::SignalIceCandidate { conn, candidate } => { - redis_connection - .rpush("dialer_candidates", wire::Candidate { conn, candidate }) - .await - .context("Failed to push candidate")?; - } - Event::ConnectionEstablished { conn } => { - start = Instant::now(); - eventloop.send_to(conn, ip4_udp_ping_packet(source, dst, &ping_body).into())?; - } - Event::ConnectionFailed { conn } => { - anyhow::bail!("Failed to establish connection: {conn}"); - } - } + return Ok(()); } - - response = redis_connection.blpop::<_, Option<(String, wire::Candidate)>>("listener_candidates", 1.0) => { - let Ok(Some((_, wire::Candidate { conn, candidate }))) = response else { - continue; - }; - eventloop.pool.add_remote_candidate(conn, candidate); + Event::SignalIceCandidate { conn, candidate } => { + redis_connection + .rpush("dialer_candidates", wire::Candidate { conn, candidate }) + .await + .context("Failed to push candidate")?; + } + Event::ConnectionEstablished { conn } => { + start = Instant::now(); + eventloop + .send_to(conn, ip4_udp_ping_packet(source, dst, &ping_body).into())?; + } + Event::ConnectionFailed { conn } => { + anyhow::bail!("Failed to establish connection: {conn}"); } } } @@ -199,33 +198,27 @@ async fn main() -> Result<()> { .await .context("Failed to push answer")?; - let mut eventloop = Eventloop::new(socket, pool); + let rx = spawn_candidate_task(redis_connection.clone(), "dialer_candidates"); + + let mut eventloop = Eventloop::new(socket, pool, rx); loop { - tokio::select! { - event = poll_fn(|cx| eventloop.poll(cx)) => { - match event? { - Event::Incoming { conn, packet } => { - eventloop.send_to(conn, ip4_udp_ping_packet(dst, source, packet.udp_payload()).into())?; - } - Event::SignalIceCandidate { conn, candidate } => { - redis_connection - .rpush("listener_candidates", wire::Candidate { conn, candidate }) - .await - .context("Failed to push candidate")?; - } - Event::ConnectionEstablished { .. } => { } - Event::ConnectionFailed { conn } => { - anyhow::bail!("Failed to establish connection: {conn}"); - } - } + match poll_fn(|cx| eventloop.poll(cx)).await? { + Event::Incoming { conn, packet } => { + eventloop.send_to( + conn, + ip4_udp_ping_packet(dst, source, packet.udp_payload()).into(), + )?; } - - response = redis_connection.blpop::<_, Option<(String, wire::Candidate)>>("dialer_candidates", 1.0) => { - let Ok(Some((_, wire::Candidate { conn, candidate }))) = response else { - continue; - }; - eventloop.pool.add_remote_candidate(conn, candidate); + Event::SignalIceCandidate { conn, candidate } => { + redis_connection + .rpush("listener_candidates", wire::Candidate { conn, candidate }) + .await + .context("Failed to push candidate")?; + } + Event::ConnectionEstablished { .. } => {} + Event::ConnectionFailed { conn } => { + anyhow::bail!("Failed to establish connection: {conn}"); } } } @@ -233,6 +226,27 @@ async fn main() -> Result<()> { }; } +fn spawn_candidate_task( + mut conn: MultiplexedConnection, + topic: &'static str, +) -> mpsc::Receiver { + let (mut sender, receiver) = mpsc::channel(0); + tokio::spawn(async move { + loop { + let candidate = conn + .blpop::<_, Option<(String, wire::Candidate)>>(topic, 1.0) + .await + .unwrap(); + + if let Some((_, candidate)) = candidate { + sender.send(candidate).await.unwrap(); + } + } + }); + + receiver +} + fn ip4_udp_ping_packet(source: Ipv4Addr, dst: Ipv4Addr, body: &[u8]) -> Ipv4Packet<'static> { assert_eq!(body.len(), 32); @@ -295,6 +309,7 @@ mod wire { serde::Deserialize, redis_macros::FromRedisValue, redis_macros::ToRedisArgs, + Debug, )] pub struct Candidate { pub conn: u64, @@ -323,18 +338,24 @@ struct Eventloop { socket: UdpSocket, pool: ConnectionPool, timeout: BoxFuture<'static, Instant>, + candidate_rx: mpsc::Receiver, read_buffer: Box<[u8; MAX_UDP_SIZE]>, write_buffer: Box<[u8; MAX_UDP_SIZE]>, } impl Eventloop { - fn new(socket: UdpSocket, pool: ConnectionPool) -> Self { + fn new( + socket: UdpSocket, + pool: ConnectionPool, + candidate_rx: mpsc::Receiver, + ) -> Self { Self { socket, pool, timeout: sleep_until(Instant::now()).boxed(), read_buffer: Box::new([0u8; MAX_UDP_SIZE]), write_buffer: Box::new([0u8; MAX_UDP_SIZE]), + candidate_rx, } } @@ -376,6 +397,15 @@ impl Eventloop { None => {} } + if let Poll::Ready(Some(wire::Candidate { conn, candidate })) = + self.candidate_rx.poll_next_unpin(cx) + { + self.pool.add_remote_candidate(conn, candidate); + + cx.waker().wake_by_ref(); + return Poll::Pending; + } + if let Poll::Ready(instant) = self.timeout.poll_unpin(cx) { self.pool.handle_timeout(instant); if let Some(timeout) = self.pool.poll_timeout() { @@ -404,6 +434,9 @@ impl Eventloop { packet: packet.to_owned(), })); } + + cx.waker().wake_by_ref(); + return Poll::Pending; } Poll::Pending