diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0318f4d05..7a9176f76 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,6 +63,11 @@ jobs: context: rust build-args: | PACKAGE=firezone-linux-client + - image_name: connection-tests + target: debug + context: rust + build-args: | + PACKAGE=firezone-connection-tests - image_name: elixir target: compiler context: elixir @@ -134,6 +139,28 @@ jobs: target: ${{ matrix.target }} tags: ${{ steps.build_docker_tags.outputs.tags }} + connection-integration-tests: + needs: build-images + runs-on: ubuntu-22.04 + permissions: + contents: read + id-token: write + pull-requests: write + env: + VERSION: ${{ github.sha }} + strategy: + fail-fast: false + matrix: + file: ['docker-compose.lan.yml'] + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/gcp-docker-login + id: login + with: + project: firezone-staging + - name: Run ${{ matrix.file }} test + run: docker compose -f rust/connection-tests/${{ matrix.file }} up --exit-code-from dialer --abort-on-container-exit + integration-tests: needs: build-images runs-on: ubuntu-22.04 diff --git a/rust/Cargo.lock b/rust/Cargo.lock index c177998e9..41b7cd2b1 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -196,6 +196,15 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" +[[package]] +name = "array-init" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23589ecb866b460d3a0f1278834750268c607e8e28a1b982c907219f3178cd72" +dependencies = [ + "nodrop", +] + [[package]] name = "asn1-rs" version = "0.5.2" @@ -909,7 +918,7 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3431df59f28accaf4cb4eed4a9acc66bea3f3c3753aa6cdc2f024174ef232af7" dependencies = [ - "smallvec", + "smallvec 1.11.2", ] [[package]] @@ -918,7 +927,7 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03915af431787e6ffdcc74c645077518c6b6e01f80b761e0fbbfa288536311b3" dependencies = [ - "smallvec", + "smallvec 1.11.2", "target-lexicon", ] @@ -1384,7 +1393,7 @@ dependencies = [ "phf 0.8.0", "proc-macro2", "quote", - "smallvec", + "smallvec 1.11.2", "syn 1.0.109", ] @@ -1897,6 +1906,43 @@ dependencies = [ "url", ] +[[package]] +name = "firezone-connection" +version = "1.0.0" +dependencies = [ + "anyhow", + "boringtun", + "firezone-relay", + "pnet_packet", + "rand 0.8.5", + "secrecy", + "str0m", + "thiserror", + "tracing", +] + +[[package]] +name = "firezone-connection-tests" +version = "1.0.0" +dependencies = [ + "anyhow", + "boringtun", + "firezone-connection", + "futures", + "hex", + "pnet_packet", + "rand 0.8.5", + "redis", + "redis-macros", + "secrecy", + "serde", + "serde-hex", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "firezone-gateway" version = "1.0.0" @@ -2450,7 +2496,7 @@ dependencies = [ "gobject-sys", "libc", "once_cell", - "smallvec", + "smallvec 1.11.2", "thiserror", ] @@ -2678,7 +2724,7 @@ dependencies = [ "parking_lot", "rand 0.8.5", "resolv-conf", - "smallvec", + "smallvec 1.11.2", "thiserror", "tokio", "tracing", @@ -3423,6 +3469,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "maybe-uninit" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00" + [[package]] name = "md-5" version = "0.10.6" @@ -4063,7 +4115,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.11.2", "windows-targets 0.48.5", ] @@ -4737,11 +4789,36 @@ dependencies = [ "percent-encoding", "pin-project-lite", "ryu", + "sha1_smol", + "socket2 0.4.10", "tokio", "tokio-util", "url", ] +[[package]] +name = "redis-macros" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60eb39e2b44d4c0f9c84e7c5fc4fc3adc8dd26ec48f1ac3a160033f7c03b18fd" +dependencies = [ + "redis", + "redis-macros-derive", + "serde", + "serde_json", +] + +[[package]] +name = "redis-macros-derive" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39550b9e94ce430a349c5490ca4efcae90ab8189603320f88c1d69f0326f169e" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -5131,6 +5208,21 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "sctp-proto" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "514eb06a3f6b1f119b6a00a9a87afac072894817d3283b0d36adc8f8a135886a" +dependencies = [ + "bytes", + "crc", + "fxhash", + "log", + "rand 0.8.5", + "slab", + "thiserror", +] + [[package]] name = "sdp" version = "0.6.0" @@ -5225,7 +5317,7 @@ dependencies = [ "phf_codegen 0.8.0", "precomputed-hash", "servo_arc", - "smallvec", + "smallvec 1.11.2", "thin-slice", ] @@ -5247,6 +5339,17 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-hex" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca37e3e4d1b39afd7ff11ee4e947efae85adfddf4841787bfa47c470e96dc26d" +dependencies = [ + "array-init", + "serde", + "smallvec 0.6.14", +] + [[package]] name = "serde_derive" version = "1.0.193" @@ -5362,6 +5465,18 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "sha-1" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", + "sha1-asm", +] + [[package]] name = "sha1" version = "0.10.6" @@ -5373,6 +5488,21 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1-asm" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ba6947745e7f86be3b8af00b7355857085dbdf8901393c89514510eb61f4e21" +dependencies = [ + "cc", +] + +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.10.8" @@ -5439,6 +5569,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "smallvec" +version = "0.6.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97fcaeba89edba30f044a10c6a3cc39df9c3f17d7cd829dd1446cab35f890e0" +dependencies = [ + "maybe-uninit", +] + [[package]] name = "smallvec" version = "1.11.2" @@ -5551,6 +5690,23 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e08d8363704e6c71fc928674353e6b7c23dcea9d82d7012c8faf2a3a025f8d0" +[[package]] +name = "str0m" +version = "0.4.1" +source = "git+https://github.com/thomaseizinger/str0m?branch=main#12575cab1c8c466cbf1e05b0b7459136eeaba8ed" +dependencies = [ + "combine", + "crc", + "hmac", + "once_cell", + "rand 0.8.5", + "sctp-proto", + "serde", + "sha-1", + "thiserror", + "tracing", +] + [[package]] name = "string_cache" version = "0.8.7" @@ -6496,7 +6652,7 @@ dependencies = [ "once_cell", "opentelemetry", "opentelemetry_sdk", - "smallvec", + "smallvec 1.11.2", "tracing", "tracing-core", "tracing-log 0.1.4", @@ -6559,7 +6715,7 @@ dependencies = [ "serde", "serde_json", "sharded-slab", - "smallvec", + "smallvec 1.11.2", "thread_local", "tracing", "tracing-core", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 3b50fea8a..e330369c6 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -5,9 +5,11 @@ members = [ "connlib/clients/shared", "connlib/shared", "connlib/tunnel", + "connlib/connection", "gateway", "linux-client", "firezone-cli-utils", + "connection-tests", "phoenix-channel", "relay", "windows-client/src-tauri", @@ -25,6 +27,7 @@ tracing-subscriber = { version = "0.3.17", features = ["parking_lot"] } secrecy = "0.8" hickory-resolver = { version = "0.24", features = ["tokio-runtime"] } webrtc = "0.9" +str0m = "0.4" futures-bounded = "0.2.1" domain = { version = "0.9", features = ["serde"] } @@ -35,6 +38,8 @@ firezone-gateway = { path = "gateway"} firezone-linux-client = { path = "linux-client"} firezone-windows-client = { path = "windows-client/src-tauri"} firezone-cli-utils = { path = "firezone-cli-utils"} +firezone-connection = { path = "connlib/connection"} +firezone-relay = { path = "relay"} connlib-shared = { path = "connlib/shared"} firezone-tunnel = { path = "connlib/tunnel"} phoenix-channel = { path = "phoenix-channel"} @@ -42,6 +47,7 @@ phoenix-channel = { path = "phoenix-channel"} [patch.crates-io] boringtun = { git = "https://github.com/cloudflare/boringtun", branch = "master" } # Contains unreleased patches we need (bump of x25519-dalek) webrtc = { git = "https://github.com/firezone/webrtc", branch = "expose-new-endpoint" } +str0m = { git = "https://github.com/thomaseizinger/str0m", branch = "main" } [profile.release] strip = true diff --git a/rust/connection-tests/Cargo.toml b/rust/connection-tests/Cargo.toml new file mode 100644 index 000000000..4bf19897b --- /dev/null +++ b/rust/connection-tests/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "firezone-connection-tests" +# mark:automatic-version +version = "1.0.0" +edition = "2021" + +[dependencies] +anyhow = "1" +boringtun = { workspace = true } +firezone-connection = { workspace = true } +futures = "0.3" +hex = "0.4" +pnet_packet = { version = "0.34" } +rand = "0.8" +redis = { version = "0.23.3", default-features = false, features = ["tokio-comp"] } +redis-macros = "0.2.1" +secrecy = { workspace = true } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +serde-hex = "0.1.0" +tokio = { version = "1", features = ["full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } diff --git a/rust/connection-tests/README.md b/rust/connection-tests/README.md new file mode 100644 index 000000000..fe10e80b7 --- /dev/null +++ b/rust/connection-tests/README.md @@ -0,0 +1,28 @@ +# firezone-connection integration tests + +This directory contains Docker-based integration tests for the `firezone-connection` crate. +Each integration test setup is a dedicated docker-compose file. + +## Running + +To run one of these tests, use the following command: + +```shell +sudo docker compose -f ./docker-compose.lan.yml up --exit-code-from dialer --abort-on-container-exit --build +``` + +This will force a re-build of the containers and exit with 0 if everything works correctly. + +## Design + +Each file consists of at least: + +- A dialer +- A listener +- A redis server + +Redis acts as the signalling channel. +Dialer and listener use it to exchange offers & answers as well as ICE candidates. + +The various files simulate different network environments. +We use nftables to simulate NATs and / or force the use of TURN servers. diff --git a/rust/connection-tests/docker-compose.lan.yml b/rust/connection-tests/docker-compose.lan.yml new file mode 100644 index 000000000..4a2026419 --- /dev/null +++ b/rust/connection-tests/docker-compose.lan.yml @@ -0,0 +1,101 @@ +version: "3.8" + +services: + dialer: + build: + target: debug + context: .. + args: + PACKAGE: firezone-connection-tests + cache_from: + - type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/connection-tests:main + image: us-east1-docker.pkg.dev/firezone-staging/firezone/connection-tests:${VERSION:-main} + init: true + environment: + ROLE: "dialer" + LISTEN_ADDR: 172.28.0.100 + REDIS_HOST: redis # All services share the `app` network. + cap_add: + - NET_ADMIN + # depends_on: + # relay: + # condition: "service_healthy" + # redis: + # condition: "service_healthy" + networks: + app: + ipv4_address: 172.28.0.100 + + listener: + build: + target: debug + context: .. + args: + PACKAGE: firezone-connection-tests + cache_from: + - type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/connection-tests:main + image: us-east1-docker.pkg.dev/firezone-staging/firezone/connection-tests:${VERSION:-main} + init: true + environment: + ROLE: "listener" + LISTEN_ADDR: 172.28.0.101 + REDIS_HOST: redis # All services share the `app` network. + cap_add: + - NET_ADMIN + # depends_on: + # relay: + # condition: "service_healthy" + # redis: + # condition: "service_healthy" + networks: + app: + ipv4_address: 172.28.0.101 + + # relay: + # environment: + # PUBLIC_IP4_ADDR: 172.28.0.102 + # PUBLIC_IP6_ADDR: fcff:3990:3990::101 + # LOWEST_PORT: 55555 + # HIGHEST_PORT: 55666 + # RUST_LOG: "debug" + # RUST_BACKTRACE: 1 + # build: + # target: debug + # context: .. + # cache_from: + # - type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/relay:main + # args: + # PACKAGE: firezone-relay + # init: true + # healthcheck: + # test: ["CMD-SHELL", "lsof -i UDP | grep firezone-relay"] + # start_period: 20s + # interval: 30s + # retries: 5 + # timeout: 5s + # ports: + # # XXX: Only 111 ports are used for local dev / testing because Docker Desktop + # # allocates a userland proxy process for each forwarded port X_X. + # # + # # Large ranges here will bring your machine to its knees. + # - "55555-55666:55555-55666/udp" + # - 3478:3478/udp + # networks: + # app: + # ipv4_address: 172.28.0.102 + + redis: + image: "redis:7-alpine" + # healthcheck: + # test: ["CMD-SHELL", "echo 'ready';"] + networks: + app: + ipv4_address: 172.28.0.103 + +networks: + app: + enable_ipv6: true + ipam: + config: + - subnet: 172.28.0.0/16 + - subnet: 2001:db8:2::/64 diff --git a/rust/connection-tests/src/main.rs b/rust/connection-tests/src/main.rs new file mode 100644 index 000000000..e756f409f --- /dev/null +++ b/rust/connection-tests/src/main.rs @@ -0,0 +1,389 @@ +use std::{ + future::poll_fn, + net::{IpAddr, Ipv4Addr}, + str::FromStr, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use anyhow::{bail, Context as _, Result}; +use boringtun::x25519::{PublicKey, StaticSecret}; +use firezone_connection::{ + Answer, ClientConnectionPool, ConnectionPool, Credentials, IpPacket, Offer, + ServerConnectionPool, +}; +use futures::{future::BoxFuture, FutureExt}; +use pnet_packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet}; +use redis::AsyncCommands; +use secrecy::{ExposeSecret as _, Secret}; +use tokio::{io::ReadBuf, net::UdpSocket}; +use tracing_subscriber::EnvFilter; + +const MAX_UDP_SIZE: usize = (1 << 16) - 1; + +#[tokio::main] +async fn main() -> Result<()> { + tokio::time::sleep(Duration::from_secs(1)).await; // Until redis is up. + + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::builder().parse("info,boringtun=debug")?) + .init(); + + let role = std::env::var("ROLE") + .context("Missing ROLE env variable")? + .parse::()?; + let listen_addr = std::env::var("LISTEN_ADDR") + .context("Missing LISTEN_ADDR env var")? + .parse::()?; + let redis_host = std::env::var("REDIS_HOST").context("Missing LISTEN_ADDR env var")?; + + let redis_client = redis::Client::open(format!("redis://{redis_host}:6379"))?; + let mut redis_connection = redis_client.get_async_connection().await?; + + let socket = UdpSocket::bind((listen_addr, 0)).await?; + let socket_addr = socket.local_addr()?; + let private_key = StaticSecret::random_from_rng(&mut rand::thread_rng()); + let public_key = PublicKey::from(&private_key); + + // The source and dst of our dummy IP packet that we send via the wireguard tunnel. + let source = Ipv4Addr::new(172, 16, 0, 1); + let dst = Ipv4Addr::new(10, 0, 0, 1); + + match role { + Role::Dialer => { + let mut pool = ClientConnectionPool::::new(private_key); + pool.add_local_interface(socket_addr); + + let offer = pool.new_connection(1, vec![], vec![]); + + redis_connection + .rpush( + "offers", + wire::Offer { + session_key: *offer.session_key.expose_secret(), + username: offer.credentials.username, + password: offer.credentials.password, + public_key: public_key.to_bytes(), + }, + ) + .await + .context("Failed to push offer")?; + + let answer = redis_connection + .blpop::<_, (String, wire::Answer)>("answers", 10) + .await + .context("Failed to pop answer")? + .1; + + pool.accept_answer( + 1, + answer.public_key.into(), + Answer { + credentials: Credentials { + username: answer.username, + password: answer.password, + }, + }, + ); + + let mut eventloop = Eventloop::new(socket, pool); + + let ping_body = rand::random::<[u8; 32]>(); + let mut start = Instant::now(); + + loop { + tokio::select! { + event = poll_fn(|cx| eventloop.poll(cx)) => { + match event? { + Event::Incoming { conn, packet } => { + anyhow::ensure!(conn == 1); + anyhow::ensure!(packet == IpPacket::Ipv4(ip4_udp_ping_packet(dst, source, packet.udp_payload()))); // Expect the listener to flip src and dst + + let rtt = start.elapsed(); + + tracing::info!("RTT is {rtt:?}"); + + return Ok(()) + } + Event::SignalIceCandidate { conn, candidate } => { + redis_connection + .rpush("dialer_candidates", wire::Candidate { conn, candidate }) + .await + .context("Failed to push candidate")?; + } + Event::ConnectionEstablished { conn } => { + start = Instant::now(); + eventloop.send_to(conn, ip4_udp_ping_packet(source, dst, &ping_body).into())?; + } + } + } + + response = redis_connection.blpop::<_, Option<(String, wire::Candidate)>>("listener_candidates", 1) => { + let Ok(Some((_, wire::Candidate { conn, candidate }))) = response else { + continue; + }; + eventloop.pool.add_remote_candidate(conn, candidate); + } + } + } + } + Role::Listener => { + let mut pool = ServerConnectionPool::::new(private_key); + pool.add_local_interface(socket_addr); + + let offer = redis_connection + .blpop::<_, (String, wire::Offer)>("offers", 10) + .await + .context("Failed to pop offer")? + .1; + + let answer = pool.accept_connection( + 1, + Offer { + session_key: Secret::new(offer.session_key), + credentials: Credentials { + username: offer.username, + password: offer.password, + }, + }, + offer.public_key.into(), + vec![], + vec![], + ); + + redis_connection + .rpush( + "answers", + wire::Answer { + public_key: public_key.to_bytes(), + username: answer.credentials.username, + password: answer.credentials.password, + }, + ) + .await + .context("Failed to push answer")?; + + let mut eventloop = Eventloop::new(socket, pool); + + loop { + tokio::select! { + event = poll_fn(|cx| eventloop.poll(cx)) => { + match event? { + Event::Incoming { conn, packet } => { + eventloop.send_to(conn, ip4_udp_ping_packet(dst, source, packet.udp_payload()).into())?; + } + Event::SignalIceCandidate { conn, candidate } => { + redis_connection + .rpush("listener_candidates", wire::Candidate { conn, candidate }) + .await + .context("Failed to push candidate")?; + } + Event::ConnectionEstablished { .. } => { } + } + } + + response = redis_connection.blpop::<_, Option<(String, wire::Candidate)>>("dialer_candidates", 1) => { + let Ok(Some((_, wire::Candidate { conn, candidate }))) = response else { + continue; + }; + eventloop.pool.add_remote_candidate(conn, candidate); + } + } + } + } + }; +} + +fn ip4_udp_ping_packet(source: Ipv4Addr, dst: Ipv4Addr, body: &[u8]) -> Ipv4Packet<'static> { + assert_eq!(body.len(), 32); + + let mut packet_buffer = [0u8; 60]; + + let mut ip4_header = + pnet_packet::ipv4::MutableIpv4Packet::new(&mut packet_buffer[..20]).unwrap(); + ip4_header.set_version(4); + ip4_header.set_source(source); + ip4_header.set_destination(dst); + ip4_header.set_next_level_protocol(IpNextHeaderProtocols::Udp); + ip4_header.set_ttl(10); + ip4_header.set_total_length(20 + 8 + 32); // IP4 + UDP + payload. + ip4_header.set_header_length(5); // Length is in number of 32bit words, i.e. 5 means 20 bytes. + ip4_header.set_checksum(pnet_packet::ipv4::checksum(&ip4_header.to_immutable())); + + let mut udp_header = + pnet_packet::udp::MutableUdpPacket::new(&mut packet_buffer[20..28]).unwrap(); + udp_header.set_source(9999); + udp_header.set_destination(9999); + udp_header.set_length(8 + 32); + udp_header.set_checksum(0); // Not necessary for IPv4, let's keep it simple. + + packet_buffer[28..60].copy_from_slice(body); + + Ipv4Packet::owned(packet_buffer.to_vec()).unwrap() +} + +mod wire { + #[derive( + serde::Serialize, + serde::Deserialize, + redis_macros::FromRedisValue, + redis_macros::ToRedisArgs, + )] + pub struct Offer { + #[serde(with = "serde_hex::SerHex::")] + pub session_key: [u8; 32], + #[serde(with = "serde_hex::SerHex::")] + pub public_key: [u8; 32], + pub username: String, + pub password: String, + } + + #[derive( + serde::Serialize, + serde::Deserialize, + redis_macros::FromRedisValue, + redis_macros::ToRedisArgs, + )] + pub struct Answer { + #[serde(with = "serde_hex::SerHex::")] + pub public_key: [u8; 32], + pub username: String, + pub password: String, + } + + #[derive( + serde::Serialize, + serde::Deserialize, + redis_macros::FromRedisValue, + redis_macros::ToRedisArgs, + )] + pub struct Candidate { + pub conn: u64, + pub candidate: String, + } +} + +enum Role { + Dialer, + Listener, +} + +impl FromStr for Role { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s { + "dialer" => Ok(Self::Dialer), + "listener" => Ok(Self::Listener), + other => bail!("unknown role: {other}"), + } + } +} + +struct Eventloop { + socket: UdpSocket, + pool: ConnectionPool, + timeout: BoxFuture<'static, Instant>, + read_buffer: Box<[u8; MAX_UDP_SIZE]>, + write_buffer: Box<[u8; MAX_UDP_SIZE]>, +} + +impl Eventloop { + fn new(socket: UdpSocket, pool: ConnectionPool) -> Self { + Self { + socket, + pool, + timeout: sleep_until(Instant::now()).boxed(), + read_buffer: Box::new([0u8; MAX_UDP_SIZE]), + write_buffer: Box::new([0u8; MAX_UDP_SIZE]), + } + } + + fn send_to(&mut self, id: u64, packet: IpPacket<'_>) -> Result<()> { + let Some((addr, msg)) = self.pool.encapsulate(id, packet)? else { + return Ok(()); + }; + + tracing::trace!(target = "wire::out", to = %addr, packet = %hex::encode(msg)); + + self.socket.try_send_to(msg, addr)?; + + Ok(()) + } + + fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + while let Some(transmit) = self.pool.poll_transmit() { + tracing::trace!(target = "wire::out", to = %transmit.dst, packet = %hex::encode(&transmit.payload)); + + self.socket.try_send_to(&transmit.payload, transmit.dst)?; + } + + match self.pool.poll_event() { + Some(firezone_connection::Event::SignalIceCandidate { + connection, + candidate, + }) => { + return Poll::Ready(Ok(Event::SignalIceCandidate { + conn: connection, + candidate, + })) + } + Some(firezone_connection::Event::ConnectionEstablished(conn)) => { + return Poll::Ready(Ok(Event::ConnectionEstablished { conn })) + } + None => {} + } + + if let Poll::Ready(instant) = self.timeout.poll_unpin(cx) { + self.pool.handle_timeout(instant); + if let Some(timeout) = self.pool.poll_timeout() { + self.timeout = sleep_until(timeout).boxed(); + } + + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + let mut read_buf = ReadBuf::new(self.read_buffer.as_mut()); + if let Poll::Ready(from) = self.socket.poll_recv_from(cx, &mut read_buf)? { + let packet = read_buf.filled(); + + tracing::trace!(target = "wire::in", %from, packet = %hex::encode(packet)); + + if let Some((conn, packet)) = self.pool.decapsulate( + self.socket.local_addr()?, + from, + packet, + Instant::now(), + self.write_buffer.as_mut(), + )? { + return Poll::Ready(Ok(Event::Incoming { + conn, + packet: packet.to_owned(), + })); + } + } + + Poll::Pending + } +} + +enum Event { + Incoming { + conn: u64, + packet: IpPacket<'static>, + }, + SignalIceCandidate { + conn: u64, + candidate: String, + }, + ConnectionEstablished { + conn: u64, + }, +} + +async fn sleep_until(deadline: Instant) -> Instant { + tokio::time::sleep_until(deadline.into()).await; + + deadline +} diff --git a/rust/connlib/connection/Cargo.toml b/rust/connlib/connection/Cargo.toml new file mode 100644 index 000000000..5c5131f9b --- /dev/null +++ b/rust/connlib/connection/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "firezone-connection" +# mark:automatic-version +version = "1.0.0" +edition = "2021" + +[dependencies] +anyhow = "1" +boringtun = { workspace = true } +firezone-relay = { workspace = true } +pnet_packet = { version = "0.34" } +rand = "0.8" +secrecy = { workspace = true } +str0m = { workspace = true } +thiserror = "1" +tracing = "0.1" diff --git a/rust/connlib/connection/src/index.rs b/rust/connlib/connection/src/index.rs new file mode 100644 index 000000000..3b222aa60 --- /dev/null +++ b/rust/connlib/connection/src/index.rs @@ -0,0 +1,49 @@ +// A basic linear-feedback shift register implemented as xorshift, used to +// distribute peer indexes across the 24-bit address space reserved for peer +// identification. +// The purpose is to obscure the total number of peers using the system and to +// ensure it requires a non-trivial amount of processing power and/or samples +// to guess other peers' indices. Anything more ambitious than this is wasted +// with only 24 bits of space. +pub(crate) struct IndexLfsr { + initial: u32, + lfsr: u32, + mask: u32, +} + +impl IndexLfsr { + /// Generate a random 24-bit nonzero integer + fn random_index() -> u32 { + const LFSR_MAX: u32 = 0xffffff; // 24-bit seed + loop { + let i = rand::random::() & LFSR_MAX; + if i > 0 { + // LFSR seed must be non-zero + break i; + } + } + } + + /// Generate the next value in the pseudorandom sequence + pub(crate) fn next(&mut self) -> u32 { + // 24-bit polynomial for randomness. This is arbitrarily chosen to + // inject bitflips into the value. + const LFSR_POLY: u32 = 0xd80000; // 24-bit polynomial + debug_assert_ne!(self.lfsr, 0); + let value = self.lfsr - 1; // lfsr will never have value of 0 + self.lfsr = (self.lfsr >> 1) ^ ((0u32.wrapping_sub(self.lfsr & 1u32)) & LFSR_POLY); + assert!(self.lfsr != self.initial, "Too many peers created"); + value ^ self.mask + } +} + +impl Default for IndexLfsr { + fn default() -> Self { + let seed = Self::random_index(); + IndexLfsr { + initial: seed, + lfsr: seed, + mask: Self::random_index(), + } + } +} diff --git a/rust/connlib/connection/src/ip_packet.rs b/rust/connlib/connection/src/ip_packet.rs new file mode 100644 index 000000000..91196b25a --- /dev/null +++ b/rust/connlib/connection/src/ip_packet.rs @@ -0,0 +1,81 @@ +use std::net::IpAddr; + +use pnet_packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, ipv6::Ipv6Packet, Packet}; + +macro_rules! for_both { + ($this:ident, |$name:ident| $body:expr) => { + match $this { + IpPacket::Ipv4($name) => $body, + IpPacket::Ipv6($name) => $body, + } + }; +} + +#[derive(Debug, PartialEq)] +pub enum IpPacket<'a> { + Ipv4(Ipv4Packet<'a>), + Ipv6(Ipv6Packet<'a>), +} + +impl<'a> IpPacket<'a> { + pub fn new(buf: &'a [u8]) -> Option { + match buf[0] >> 4 { + 4 => Some(IpPacket::Ipv4(Ipv4Packet::new(buf)?)), + 6 => Some(IpPacket::Ipv6(Ipv6Packet::new(buf)?)), + _ => None, + } + } + + pub fn to_owned(&self) -> IpPacket<'static> { + match self { + IpPacket::Ipv4(i) => Ipv4Packet::owned(i.packet().to_vec()) + .expect("owned packet is still valid") + .into(), + IpPacket::Ipv6(i) => Ipv6Packet::owned(i.packet().to_vec()) + .expect("owned packet is still valid") + .into(), + } + } + + pub fn source(&self) -> IpAddr { + for_both!(self, |i| i.get_source().into()) + } + + pub fn destination(&self) -> IpAddr { + for_both!(self, |i| i.get_destination().into()) + } + + pub fn udp_payload(&self) -> &[u8] { + debug_assert_eq!( + match self { + IpPacket::Ipv4(i) => i.get_next_level_protocol(), + IpPacket::Ipv6(i) => i.get_next_header(), + }, + IpNextHeaderProtocols::Udp + ); + + for_both!(self, |i| &i.payload()[8..]) + } +} + +impl<'a> From> for IpPacket<'a> { + fn from(value: Ipv4Packet<'a>) -> Self { + Self::Ipv4(value) + } +} + +impl<'a> From> for IpPacket<'a> { + fn from(value: Ipv6Packet<'a>) -> Self { + Self::Ipv6(value) + } +} + +impl pnet_packet::Packet for IpPacket<'_> { + fn packet(&self) -> &[u8] { + for_both!(self, |i| i.packet()) + } + + fn payload(&self) -> &[u8] { + for_both!(self, |i| i.payload()) + } +} diff --git a/rust/connlib/connection/src/lib.rs b/rust/connlib/connection/src/lib.rs new file mode 100644 index 000000000..313cf1015 --- /dev/null +++ b/rust/connlib/connection/src/lib.rs @@ -0,0 +1,9 @@ +mod index; +mod ip_packet; +mod pool; + +pub use ip_packet::IpPacket; +pub use pool::{ + Answer, ClientConnectionPool, ConnectionPool, Credentials, Error, Event, Offer, + ServerConnectionPool, +}; diff --git a/rust/connlib/connection/src/pool.rs b/rust/connlib/connection/src/pool.rs new file mode 100644 index 000000000..becdb6dbd --- /dev/null +++ b/rust/connlib/connection/src/pool.rs @@ -0,0 +1,632 @@ +use boringtun::noise::{Tunn, TunnResult}; +use boringtun::x25519::PublicKey; +use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; +use core::fmt; +use pnet_packet::ipv4::Ipv4Packet; +use pnet_packet::ipv6::Ipv6Packet; +use pnet_packet::Packet; +use rand::random; +use secrecy::{ExposeSecret, Secret}; +use std::hash::Hash; +use std::marker::PhantomData; +use std::time::{Duration, Instant}; +use std::{ + collections::{HashMap, HashSet, VecDeque}, + net::SocketAddr, + sync::Arc, +}; +use str0m::ice::{IceAgent, IceAgentEvent, IceCreds}; +use str0m::net::{Protocol, Receive}; +use str0m::{Candidate, StunMessage}; + +use crate::index::IndexLfsr; +use crate::IpPacket; + +// Note: Taken from boringtun +const HANDSHAKE_RATE_LIMIT: u64 = 100; + +const MAX_UDP_SIZE: usize = (1 << 16) - 1; + +/// Manages a set of wireguard connections for a server. +pub type ServerConnectionPool = ConnectionPool; +/// Manages a set of wireguard connections for a client. +pub type ClientConnectionPool = ConnectionPool; + +pub enum Server {} +pub enum Client {} + +pub struct ConnectionPool { + private_key: StaticSecret, + index: IndexLfsr, + rate_limiter: Arc, + local_interfaces: HashSet, + buffered_transmits: VecDeque, + + next_rate_limiter_reset: Option, + + initial_connections: HashMap, + negotiated_connections: HashMap, + pending_events: VecDeque>, + + buffer: Box<[u8; MAX_UDP_SIZE]>, + + marker: PhantomData, +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Unknown interface")] + UnknownInterface, + #[error("Failed to decapsulate: {0:?}")] // TODO: Upstream an std::error::Error impl + Decapsulate(boringtun::noise::errors::WireGuardError), + #[error("Failed to encapsulate: {0:?}")] + Encapsulate(boringtun::noise::errors::WireGuardError), + #[error("Unmatched packet")] + UnmatchedPacket, + #[error("Not connected")] + NotConnected, +} + +impl ConnectionPool +where + TId: Eq + Hash + Copy + fmt::Display, +{ + pub fn new(private_key: StaticSecret) -> Self { + let public_key = &(&private_key).into(); + Self { + private_key, + marker: Default::default(), + index: IndexLfsr::default(), + rate_limiter: Arc::new(RateLimiter::new(public_key, HANDSHAKE_RATE_LIMIT)), + local_interfaces: HashSet::default(), + buffered_transmits: VecDeque::default(), + next_rate_limiter_reset: None, + negotiated_connections: HashMap::default(), + pending_events: VecDeque::default(), + initial_connections: HashMap::default(), + buffer: Box::new([0u8; MAX_UDP_SIZE]), + } + } + + pub fn add_local_interface(&mut self, local_addr: SocketAddr) { + self.local_interfaces.insert(local_addr); + + // TODO: Add host candidate to all existing connections here. + } + + pub fn add_remote_candidate(&mut self, id: TId, candidate: String) { + let candidate = match Candidate::from_sdp_string(&candidate) { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to parse candidate: {e}"); + return; + } + }; + + if let Some(agent) = self.agent_mut(id) { + agent.add_remote_candidate(candidate); + } + } + + /// Decapsulate an incoming packet. + /// + /// # Returns + /// + /// - `Ok(None)` if the packet was handled internally, for example, a response from a TURN server. + /// - `Ok(Some)` if the packet was an encrypted wireguard packet from a peer. + /// The `Option` contains the connection on which the packet was decrypted. + pub fn decapsulate<'s>( + &mut self, + local: SocketAddr, + from: SocketAddr, + packet: &[u8], + now: Instant, + buffer: &'s mut [u8], + ) -> Result)>, Error> { + if !self.local_interfaces.contains(&local) { + return Err(Error::UnknownInterface); + } + + // TODO: First thing we need to check if the message is from one of our STUN / TURN servers AND it is a STUN message (starts with 0x03) + // ... + // ... + + // Next: If we can parse the message as a STUN message, cycle through all agents to check which one it is for. + if let Ok(stun_message) = StunMessage::parse(packet) { + for (_, conn) in self.initial_connections.iter_mut() { + // TODO: `accepts_message` cannot demultiplexing multiple connections until https://github.com/algesten/str0m/pull/418 is merged. + if conn.agent.accepts_message(&stun_message) { + conn.agent.handle_receive( + now, + Receive { + proto: Protocol::Udp, + source: from, + destination: local, + contents: str0m::net::DatagramRecv::Stun(stun_message), + }, + ); + return Ok(None); + } + } + + for (_, conn) in self.negotiated_connections.iter_mut() { + // Would the ICE agent of this connection like to handle the packet? + if conn.agent.accepts_message(&stun_message) { + conn.agent.handle_receive( + now, + Receive { + proto: Protocol::Udp, + source: from, + destination: local, + contents: str0m::net::DatagramRecv::Stun(stun_message), + }, + ); + return Ok(None); + } + } + } + + for (id, conn) in self.negotiated_connections.iter_mut() { + if !conn.accepts(from) { + continue; + } + + // TODO: I think eventually, here is where we'd unwrap a channel data message. + + return match conn.tunnel.decapsulate(None, packet, buffer) { + TunnResult::Done => Ok(None), + TunnResult::Err(e) => Err(Error::Decapsulate(e)), + + // For WriteToTunnel{V4,V6}, boringtun returns the source IP of the packet that was tunneled to us. + // I am guessing this was done for convenience reasons. + // In our API, we parse the packets directly as an IpPacket. + // Thus, the caller can query whatever data they'd like, not just the source IP so we don't return it in addition. + TunnResult::WriteToTunnelV4(packet, ip) => { + let ipv4_packet = Ipv4Packet::new(packet).expect("boringtun verifies validity"); + debug_assert_eq!(ipv4_packet.get_source(), ip); + + Ok(Some((*id, ipv4_packet.into()))) + } + TunnResult::WriteToTunnelV6(packet, ip) => { + let ipv6_packet = Ipv6Packet::new(packet).expect("boringtun verifies validity"); + debug_assert_eq!(ipv6_packet.get_source(), ip); + + Ok(Some((*id, ipv6_packet.into()))) + } + + // During normal operation, i.e. when the tunnel is active, decapsulating a packet straight yields the decrypted packet. + // However, in case `Tunn` has buffered packets, they may be returned here instead. + // This should be fairly rare which is why we just allocate these and return them from `poll_transmit` instead. + // Overall, this results in a much nicer API for our caller and should not affect performance. + TunnResult::WriteToNetwork(bytes) => { + self.buffered_transmits.push_back(Transmit { + dst: from, + payload: bytes.to_vec(), + }); + + while let TunnResult::WriteToNetwork(packet) = + conn.tunnel + .decapsulate(None, &[], self.buffer.as_mut_slice()) + { + self.buffered_transmits.push_back(Transmit { + dst: from, + payload: packet.to_vec(), + }); + } + + Ok(None) + } + }; + } + + Err(Error::UnmatchedPacket) + } + + /// Encapsulate an outgoing IP packet. + /// + /// Wireguard is an IP tunnel, so we "enforce" that only IP packets are sent through it. + /// We say "enforce" an [`IpPacket`] can be created from an (almost) arbitrary byte buffer at virtually no cost. + /// Nevertheless, using [`IpPacket`] in our API has good documentation value. + pub fn encapsulate<'s>( + &'s mut self, + connection: TId, + packet: IpPacket<'_>, + ) -> Result, Error> { + // TODO: We need to return, which local socket to use to send the data. + let conn = self + .negotiated_connections + .get_mut(&connection) + .ok_or(Error::NotConnected)?; + + let remote_socket = conn.remote_socket.ok_or(Error::NotConnected)?; + + // TODO: If we are connected via TURN, wrap in data channel message here. + + match conn + .tunnel + .encapsulate(packet.packet(), self.buffer.as_mut()) + { + TunnResult::Done => Ok(None), + TunnResult::Err(e) => Err(Error::Encapsulate(e)), + TunnResult::WriteToNetwork(packet) => Ok(Some((remote_socket, packet))), + TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => { + unreachable!("never returned from encapsulate") + } + } + } + + /// Returns a pending [`Event`] from the pool. + pub fn poll_event(&mut self) -> Option> { + for (id, conn) in self.negotiated_connections.iter_mut() { + while let Some(event) = conn.agent.poll_event() { + match event { + IceAgentEvent::DiscoveredRecv { source, .. } => { + conn.possible_sockets.insert(source); + // TODO: Here is where we'd allocate channels. + } + IceAgentEvent::IceRestart(_) => {} + IceAgentEvent::IceConnectionStateChange(_) => {} + IceAgentEvent::NominatedSend { destination, .. } => { + let old = conn.remote_socket; + + conn.remote_socket = Some(destination); + + match old { + Some(old) => { + tracing::info!(%id, new = %destination, %old, "Migrating connection to peer") + } + None => { + tracing::info!(%id, %destination, "Connected to peer"); + return Some(Event::ConnectionEstablished(*id)); + } + } + } + } + } + } + + self.pending_events.pop_front() + } + + /// Returns, when [`ConnectionPool::handle_timeout`] should be called next. + /// + /// This function only takes `&mut self` because it caches certain computations internally. + /// The returned timestamp will **not** change unless other state is modified. + pub fn poll_timeout(&mut self) -> Option { + let mut connection_timeout = None; + + // TODO: Do we need to poll ice agents of initial connections?? + + for c in self.negotiated_connections.values_mut() { + connection_timeout = earliest(connection_timeout, c.poll_timeout()); + } + + earliest(connection_timeout, self.next_rate_limiter_reset) + } + + /// Advances time within the [`ConnectionPool`]. + /// + /// This advances time within the ICE agent, updates timers within all wireguard connections as well as resets wireguard's rate limiter (if necessary). + pub fn handle_timeout(&mut self, now: Instant) { + for c in self.negotiated_connections.values_mut() { + self.buffered_transmits.extend(c.handle_timeout(now)); + } + + let next_reset = *self.next_rate_limiter_reset.get_or_insert(now); + + if now >= next_reset { + self.rate_limiter.reset_count(); + self.next_rate_limiter_reset = Some(now + Duration::from_secs(1)); + } + } + + /// Returns buffered data that needs to be sent on the socket. + pub fn poll_transmit(&mut self) -> Option { + for conn in self.initial_connections.values_mut() { + if let Some(transmit) = conn.agent.poll_transmit() { + return Some(Transmit { + dst: transmit.destination, + payload: transmit.contents.into(), + }); + } + } + + for conn in self.negotiated_connections.values_mut() { + if let Some(transmit) = conn.agent.poll_transmit() { + return Some(Transmit { + dst: transmit.destination, + payload: transmit.contents.into(), + }); + } + } + + self.buffered_transmits.pop_front() + } + + fn agent_mut(&mut self, id: TId) -> Option<&mut IceAgent> { + let maybe_initial_connection = self.initial_connections.get_mut(&id).map(|i| &mut i.agent); + let maybe_established_connection = self + .negotiated_connections + .get_mut(&id) + .map(|c| &mut c.agent); + + maybe_initial_connection.or(maybe_established_connection) + } +} + +impl ConnectionPool +where + TId: Eq + Hash + Copy, +{ + /// Create a new connection indexed by the given ID. + /// + /// Out of all configured STUN and TURN servers, the connection will only use the ones provided here. + /// The returned [`Offer`] must be passed to the remote via a signalling channel. + pub fn new_connection( + &mut self, + id: TId, + allowed_stun_servers: Vec, + allowed_turn_servers: Vec, + ) -> Offer { + let mut agent = IceAgent::new(); + agent.set_controlling(true); + + for local in self.local_interfaces.iter().copied() { + let candidate = match Candidate::host(local, Protocol::Udp) { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to generate host candidate from addr: {e}"); + continue; + } + }; + + if agent.add_local_candidate(candidate.clone()) { + self.pending_events.push_back(Event::SignalIceCandidate { + connection: id, + candidate: candidate.to_sdp_string(), + }); + } + } + + let session_key = Secret::new(random()); + let ice_creds = agent.local_credentials(); + + let params = Offer { + session_key: session_key.clone(), + credentials: Credentials { + username: ice_creds.ufrag.clone(), + password: ice_creds.pass.clone(), + }, + }; + + self.initial_connections.insert( + id, + InitialConnection { + agent, + session_key, + stun_servers: allowed_stun_servers, + turn_servers: allowed_turn_servers, + }, + ); + + params + } + + /// Accept an [`Answer`] from the remote for a connection previously created via [`ConnectionPool::new_connection`]. + pub fn accept_answer(&mut self, id: TId, remote: PublicKey, answer: Answer) { + let Some(initial) = self.initial_connections.remove(&id) else { + return; // TODO: Better error handling + }; + + let mut agent = initial.agent; + agent.set_remote_credentials(IceCreds { + ufrag: answer.credentials.username, + pass: answer.credentials.password, + }); + + self.negotiated_connections.insert( + id, + Connection { + agent, + tunnel: Tunn::new( + self.private_key.clone(), + remote, + Some(*initial.session_key.expose_secret()), + None, + self.index.next(), + Some(self.rate_limiter.clone()), + ), + _stun_servers: initial.stun_servers, + _turn_servers: initial.turn_servers, + next_timer_update: None, + remote_socket: None, + possible_sockets: HashSet::default(), + }, + ); + } +} + +impl ConnectionPool +where + TId: Eq + Hash + Copy, +{ + pub fn accept_connection( + &mut self, + id: TId, + offer: Offer, + remote: PublicKey, + allowed_stun_servers: Vec, + allowed_turn_servers: Vec, + ) -> Answer { + let mut agent = IceAgent::new(); + agent.set_controlling(false); + agent.set_remote_credentials(IceCreds { + ufrag: offer.credentials.username, + pass: offer.credentials.password, + }); + let answer = Answer { + credentials: Credentials { + username: agent.local_credentials().ufrag.clone(), + password: agent.local_credentials().pass.clone(), + }, + }; + + for local in self.local_interfaces.iter().copied() { + let candidate = match Candidate::host(local, Protocol::Udp) { + Ok(c) => c, + Err(e) => { + tracing::debug!("Failed to generate host candidate from addr: {e}"); + continue; + } + }; + + if agent.add_local_candidate(candidate.clone()) { + self.pending_events.push_back(Event::SignalIceCandidate { + connection: id, + candidate: candidate.to_sdp_string(), + }); + } + } + + self.negotiated_connections.insert( + id, + Connection { + agent, + tunnel: Tunn::new( + self.private_key.clone(), + remote, + Some(*offer.session_key.expose_secret()), + None, + self.index.next(), + Some(self.rate_limiter.clone()), + ), + _stun_servers: allowed_stun_servers, + _turn_servers: allowed_turn_servers, + next_timer_update: None, + remote_socket: None, + possible_sockets: HashSet::default(), + }, + ); + + answer + } +} + +pub struct Offer { + /// The Wireguard session key for a connection. + pub session_key: Secret<[u8; 32]>, + pub credentials: Credentials, +} + +pub struct Answer { + pub credentials: Credentials, +} + +pub struct Credentials { + /// The ICE username (ufrag). + pub username: String, + /// The ICE password. + pub password: String, +} + +pub enum Event { + /// Signal the ICE candidate to the remote via the signalling channel. + /// + /// Candidates are in SDP format although this may change and should be considered an implementation detail of the application. + SignalIceCandidate { + connection: TId, + candidate: String, + }, + ConnectionEstablished(TId), +} + +pub struct Transmit { + pub dst: SocketAddr, + pub payload: Vec, +} + +pub struct InitialConnection { + agent: IceAgent, + session_key: Secret<[u8; 32]>, + stun_servers: Vec, + turn_servers: Vec, +} + +struct Connection { + agent: IceAgent, + + tunnel: Tunn, + next_timer_update: Option, + + // When this is `Some`, we are connected. + remote_socket: Option, + // Socket addresses from which we might receive data (even before we are connected). + possible_sockets: HashSet, + + _stun_servers: Vec, + _turn_servers: Vec, +} + +impl Connection { + /// Checks if we want to accept a packet from a certain address. + /// + /// Whilst we establish connections, we may see traffic from a certain address, prior to the negotiation being fully complete. + /// We already want to accept that traffic and not throw it away. + fn accepts(&self, addr: SocketAddr) -> bool { + let from_connected_remote = self.remote_socket.is_some_and(|r| r == addr); + let from_possible_remote = self.possible_sockets.contains(&addr); + + from_connected_remote || from_possible_remote + } + + fn poll_timeout(&mut self) -> Option { + let agent_timeout = self.agent.poll_timeout(); + let next_wg_timer = self.next_timer_update; + + earliest(agent_timeout, next_wg_timer) + } + + fn handle_timeout(&mut self, now: Instant) -> Option { + self.agent.handle_timeout(now); + + let remote = self.remote_socket?; + let next_timer_update = self.next_timer_update?; + + if now >= next_timer_update { + self.next_timer_update = Some(now + Duration::from_nanos(1)); + + /// [`boringtun`] requires us to pass buffers in where it can construct its packets. + /// + /// When updating the timers, the largest packet that we may have to send is `148` bytes as per `HANDSHAKE_INIT_SZ` constant in [`boringtun`]. + const MAX_SCRATCH_SPACE: usize = 148; + + let mut buf = [0u8; MAX_SCRATCH_SPACE]; + + match self.tunnel.update_timers(&mut buf) { + TunnResult::Done => {} + TunnResult::Err(e) => { + // TODO: Handle this error. I think it can only be an expired connection so we should return a very specific error to the caller to make this easy to handle! + panic!("{e:?}") + } + TunnResult::WriteToNetwork(b) => { + return Some(Transmit { + dst: remote, + payload: b.to_vec(), + }); + } + _ => panic!("Unexpected result from update_timers"), + }; + } + + None + } +} + +fn earliest(left: Option, right: Option) -> Option { + match (left, right) { + (None, None) => None, + (Some(left), Some(right)) => Some(std::cmp::min(left, right)), + (Some(left), None) => Some(left), + (None, Some(right)) => Some(right), + } +}