diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4679e9083..9919faede 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -151,7 +151,10 @@ jobs: strategy: fail-fast: false matrix: - file: ['docker-compose.lan.yml'] + file: [ + 'docker-compose.lan.yml', + 'docker-compose.wan.yml' + ] steps: - uses: actions/checkout@v4 - uses: ./.github/actions/gcp-docker-login @@ -161,7 +164,7 @@ jobs: - name: Run ${{ matrix.file }} test run: | sudo sysctl -w vm.overcommit_memory=1 - docker compose -f rust/connection-tests/${{ matrix.file }} up --exit-code-from dialer --abort-on-container-exit + timeout 600 docker compose -f rust/connection-tests/${{ matrix.file }} up --exit-code-from dialer --abort-on-container-exit integration-tests: needs: build-images diff --git a/NOTICE.txt b/NOTICE.txt index 4072f081f..5999b425c 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -157,3 +157,27 @@ Prebuilt Binaries License details, features, specifications, capabilities, functions, licensing terms, release dates, APIs, ABIs, general availability, or other characteristics of the Software. + +=== + +Portions of this product (`rust/connection-tests`) contain derivative work from https://github.com/libp2p/test-plans. Its license is included below: + +The MIT License (MIT) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 222724587..3e641189b 100755 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1975,11 +1975,13 @@ version = "1.0.0" dependencies = [ "anyhow", "boringtun", + "bytecodec", "firezone-relay", "pnet_packet", "rand 0.8.5", "secrecy", "str0m", + "stun_codec", "thiserror", "tracing", ] diff --git a/rust/Dockerfile b/rust/Dockerfile index 563bc3ce3..3107df8d0 100644 --- a/rust/Dockerfile +++ b/rust/Dockerfile @@ -124,7 +124,7 @@ CMD $PACKAGE FROM runtime AS debug RUN set -xe \ - && apk add --no-cache iperf3 jq + && apk add --no-cache iperf3 bind-tools iproute2 jq ARG TARGET COPY --from=builder /build/target/${TARGET}/debug/${PACKAGE} . diff --git a/rust/connection-tests/docker-compose.wan.yml b/rust/connection-tests/docker-compose.wan.yml new file mode 100644 index 000000000..8793e2f03 --- /dev/null +++ b/rust/connection-tests/docker-compose.wan.yml @@ -0,0 +1,147 @@ +version: "3.8" +name: wan-hp-integration-test + +services: + dialer: + build: + target: debug + context: .. + args: + PACKAGE: firezone-connection-tests + cache_from: + - type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/connection-tests:main + image: us-east1-docker.pkg.dev/firezone-staging/firezone/connection-tests:${VERSION:-main} + environment: + ROLE: "dialer" + cap_add: + - NET_ADMIN + entrypoint: /bin/sh + command: + - -c + - | + set -ex + + ROUTER_IP=$$(dig +short dialer_router) + INTERNET_SUBNET=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/networks/wan-hp-integration-test_wan | jq -r '.IPAM.Config[0].Subnet') + + ip route add $$INTERNET_SUBNET via $$ROUTER_IP dev eth0 + + export STUN_SERVER=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/containers/wan-hp-integration-test-relay-1/json | jq -r '.NetworkSettings.Networks."wan-hp-integration-test_wan".IPAddress') + export REDIS_HOST=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/containers/wan-hp-integration-test-redis-1/json | jq -r '.NetworkSettings.Networks."wan-hp-integration-test_wan".IPAddress') + + firezone-connection-tests + depends_on: + - dialer_router + - redis + networks: + - lan1 + volumes: + - /var/run/docker.sock:/var/run/docker.sock + + dialer_router: + init: true + build: + context: ./router + cap_add: + - NET_ADMIN + networks: + - lan1 + - wan + + listener: + build: + target: debug + context: .. + args: + PACKAGE: firezone-connection-tests + cache_from: + - type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/connection-tests:main + image: us-east1-docker.pkg.dev/firezone-staging/firezone/connection-tests:${VERSION:-main} + init: true + environment: + ROLE: "listener" + entrypoint: /bin/sh + command: + - -c + - | + set -ex + + ROUTER_IP=$$(dig +short listener_router) + INTERNET_SUBNET=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/networks/wan-hp-integration-test_wan | jq -r '.IPAM.Config[0].Subnet') + + ip route add $$INTERNET_SUBNET via $$ROUTER_IP dev eth0 + + export STUN_SERVER=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/containers/wan-hp-integration-test-relay-1/json | jq -r '.NetworkSettings.Networks."wan-hp-integration-test_wan".IPAddress') + export REDIS_HOST=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/containers/wan-hp-integration-test-redis-1/json | jq -r '.NetworkSettings.Networks."wan-hp-integration-test_wan".IPAddress') + + firezone-connection-tests + cap_add: + - NET_ADMIN + depends_on: + - listener_router + - redis + networks: + - lan2 + volumes: + - /var/run/docker.sock:/var/run/docker.sock + + listener_router: + init: true + build: + context: ./router + cap_add: + - NET_ADMIN + networks: + - lan2 + - wan + + relay: + environment: + LOWEST_PORT: 55555 + HIGHEST_PORT: 55666 + RUST_LOG: "debug" + RUST_BACKTRACE: 1 + build: + target: debug + context: .. + cache_from: + - type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/relay:main + args: + PACKAGE: firezone-relay + image: us-east1-docker.pkg.dev/firezone-staging/firezone/relay:${VERSION:-main} + init: true + healthcheck: + test: ["CMD-SHELL", "lsof -i UDP | grep firezone-relay"] + start_period: 20s + interval: 30s + retries: 5 + timeout: 5s + entrypoint: /bin/sh + command: + - -c + - | + set -ex; + export PUBLIC_IP4_ADDR=$(ip -json addr show eth0 | jq '.[0].addr_info[0].local' -r) + + firezone-relay + ports: + # XXX: Only 111 ports are used for local dev / testing because Docker Desktop + # allocates a userland proxy process for each forwarded port X_X. + # + # Large ranges here will bring your machine to its knees. + - "55555-55666:55555-55666/udp" + - 3478:3478/udp + networks: + - wan + + redis: + image: "redis:7-alpine" + healthcheck: + test: ["CMD-SHELL", "echo 'ready';"] + networks: + - wan + +networks: + lan1: + lan2: + wan: diff --git a/rust/connection-tests/router/Dockerfile b/rust/connection-tests/router/Dockerfile new file mode 100644 index 000000000..92e4570d5 --- /dev/null +++ b/rust/connection-tests/router/Dockerfile @@ -0,0 +1,11 @@ +FROM debian:12-slim + +ARG DEBIAN_FRONTEND=noninteractive +RUN --mount=type=cache,target=/var/cache/apt apt-get update && apt-get -y install iproute2 nftables conntrack + +COPY *.sh /scripts/ +RUN chmod +x /scripts/*.sh + +HEALTHCHECK CMD [ "sh", "-c", "test $(cat /tmp/setup_done) = 1" ] + +ENTRYPOINT ["./scripts/run.sh"] diff --git a/rust/connection-tests/router/README.md b/rust/connection-tests/router/README.md new file mode 100644 index 000000000..38ec6b6ff --- /dev/null +++ b/rust/connection-tests/router/README.md @@ -0,0 +1,18 @@ +# Router + +This directory contains a Debian-based router implemented on top of nftables. + +It expects to be run with two network interfaces: + +- `eth1`: The "external" interface. +- `eth0`: The "internal" interface. + +The order of these interfaces depends on lexical sorting the docker networks names. + +The order of these is important. +The router cannot possibly know which one is which and thus assumes that `eth0` is the external one and `eth1` the internal one. +The firewall is set up to take incoming traffic on `eth1` and forward + masquerade it to `eth0`. + +It also expects an env variable `DELAY_MS` to be set and will apply this delay as part of the routing process[^1]. + +[^1]: This is done via `tc qdisc` which only works for egress traffic. To ensure the delay applies in both directions, we divide it by 2 and apply it on both interfaces. diff --git a/rust/connection-tests/router/run.sh b/rust/connection-tests/router/run.sh new file mode 100644 index 000000000..edbf370ae --- /dev/null +++ b/rust/connection-tests/router/run.sh @@ -0,0 +1,18 @@ +#!/bin/sh + +set -ex + +# Set up NAT +nft add table ip nat +nft add chain ip nat postrouting { type nat hook postrouting priority 100 \; } +nft add rule ip nat postrouting masquerade + +# Assumption after a long debugging session involving Gabi, Jamil and Thomas: +# On the same machine, the kernel cannot differentiate between incoming and outgoing packets across different network namespaces within the firewall and NAT mapping table. +# As a result, even UDP hole-punching is time-sensitive and we thus need to make sure that we first send a packet _out_ through the router before the other one is incoming. +# To achieve this, we set an absurdly high latency of 300ms for the WAN network. +tc qdisc add dev eth1 root netem delay 300ms + +echo "1" > /tmp/setup_done # This will be checked by our docker HEALTHCHECK + +conntrack --event --proto UDP --output timestamp # Display a real-time log of NAT events in the kernel. diff --git a/rust/connection-tests/src/main.rs b/rust/connection-tests/src/main.rs index 5b8dd6f75..ffd36faf5 100644 --- a/rust/connection-tests/src/main.rs +++ b/rust/connection-tests/src/main.rs @@ -1,6 +1,7 @@ use std::{ + collections::HashSet, future::poll_fn, - net::Ipv4Addr, + net::{IpAddr, Ipv4Addr, SocketAddr}, str::FromStr, task::{Context, Poll}, time::Instant, @@ -42,6 +43,13 @@ async fn main() -> Result<()> { .ip .to_std(); + let stun_server = std::env::var("STUN_SERVER") + .ok() + .map(|a| a.parse::()) + .transpose() + .context("Failed to parse `STUN_SERVER`")? + .map(|ip| SocketAddr::new(ip, 3478)); + tracing::info!(%listen_addr); let redis_host = std::env::var("REDIS_HOST").context("Missing REDIS_HOST env var")?; @@ -63,7 +71,8 @@ async fn main() -> Result<()> { let mut pool = ClientConnectionPool::::new(private_key); pool.add_local_interface(socket_addr); - let offer = pool.new_connection(1, vec![], vec![]); + let offer = + pool.new_connection(1, stun_server.into_iter().collect(), HashSet::default()); redis_connection .rpush( @@ -124,6 +133,9 @@ async fn main() -> Result<()> { start = Instant::now(); eventloop.send_to(conn, ip4_udp_ping_packet(source, dst, &ping_body).into())?; } + Event::ConnectionFailed { conn } => { + anyhow::bail!("Failed to establish connection: {conn}"); + } } } @@ -156,8 +168,8 @@ async fn main() -> Result<()> { }, }, offer.public_key.into(), - vec![], - vec![], + stun_server.into_iter().collect(), + HashSet::default(), ); redis_connection @@ -188,6 +200,9 @@ async fn main() -> Result<()> { .context("Failed to push candidate")?; } Event::ConnectionEstablished { .. } => { } + Event::ConnectionFailed { conn } => { + anyhow::bail!("Failed to establish connection: {conn}"); + } } } @@ -340,6 +355,9 @@ impl Eventloop { Some(firezone_connection::Event::ConnectionEstablished(conn)) => { return Poll::Ready(Ok(Event::ConnectionEstablished { conn })) } + Some(firezone_connection::Event::ConnectionFailed(conn)) => { + return Poll::Ready(Ok(Event::ConnectionFailed { conn })) + } None => {} } @@ -389,6 +407,9 @@ enum Event { ConnectionEstablished { conn: u64, }, + ConnectionFailed { + conn: u64, + }, } async fn sleep_until(deadline: Instant) -> Instant { diff --git a/rust/connlib/connection/Cargo.toml b/rust/connlib/connection/Cargo.toml index 5c5131f9b..899f8291b 100644 --- a/rust/connlib/connection/Cargo.toml +++ b/rust/connlib/connection/Cargo.toml @@ -14,3 +14,5 @@ secrecy = { workspace = true } str0m = { workspace = true } thiserror = "1" tracing = "0.1" +stun_codec = "0.3.4" +bytecodec = "0.4.15" diff --git a/rust/connlib/connection/src/lib.rs b/rust/connlib/connection/src/lib.rs index 313cf1015..f3aa7c471 100644 --- a/rust/connlib/connection/src/lib.rs +++ b/rust/connlib/connection/src/lib.rs @@ -1,6 +1,7 @@ mod index; mod ip_packet; mod pool; +mod stun_binding; pub use ip_packet::IpPacket; pub use pool::{ diff --git a/rust/connlib/connection/src/pool.rs b/rust/connlib/connection/src/pool.rs index 6583c5a9a..50be51d94 100644 --- a/rust/connlib/connection/src/pool.rs +++ b/rust/connlib/connection/src/pool.rs @@ -17,9 +17,10 @@ use std::{ }; use str0m::ice::{IceAgent, IceAgentEvent, IceCreds}; use str0m::net::{Protocol, Receive}; -use str0m::{Candidate, StunMessage}; +use str0m::{Candidate, IceConnectionState, StunMessage}; use crate::index::IndexLfsr; +use crate::stun_binding::StunBinding; use crate::IpPacket; // Note: Taken from boringtun @@ -44,6 +45,8 @@ pub struct ConnectionPool { next_rate_limiter_reset: Option, + stun_servers: HashMap, + initial_connections: HashMap, negotiated_connections: HashMap, pending_events: VecDeque>, @@ -85,6 +88,7 @@ where pending_events: VecDeque::default(), initial_connections: HashMap::default(), buffer: Box::new([0u8; MAX_UDP_SIZE]), + stun_servers: HashMap::default(), } } @@ -127,9 +131,20 @@ where return Err(Error::UnknownInterface); } - // TODO: First thing we need to check if the message is from one of our STUN / TURN servers AND it is a STUN message (starts with 0x03) - // ... - // ... + // First, check if a `StunBinding` wants the packet + if let Some(binding) = self.stun_servers.get_mut(&from) { + if binding.handle_input(from, packet, now) { + // If it handled the packet, drain its events to ensure we update the candidates of all connections. + drain_binding_events( + from, + binding, + &mut self.initial_connections, + &mut self.negotiated_connections, + &mut self.pending_events, + ); + return Ok(None); + } + } // Next: If we can parse the message as a STUN message, cycle through all agents to check which one it is for. if let Ok(stun_message) = StunMessage::parse(packet) { @@ -270,8 +285,9 @@ where conn.possible_sockets.insert(source); // TODO: Here is where we'd allocate channels. } - IceAgentEvent::IceRestart(_) => {} - IceAgentEvent::IceConnectionStateChange(_) => {} + IceAgentEvent::IceConnectionStateChange(IceConnectionState::Disconnected) => { + return Some(Event::ConnectionFailed(*id)); + } IceAgentEvent::NominatedSend { destination, .. } => match conn.remote_socket { Some(old) if old != destination => { tracing::info!(%id, new = %destination, %old, "Migrating connection to peer"); @@ -285,6 +301,7 @@ where } _ => {} }, + _ => {} } } } @@ -304,6 +321,9 @@ where for c in self.negotiated_connections.values_mut() { connection_timeout = earliest(connection_timeout, c.poll_timeout()); } + for b in self.stun_servers.values_mut() { + connection_timeout = earliest(connection_timeout, b.poll_timeout()); + } earliest(connection_timeout, self.next_rate_limiter_reset) } @@ -316,6 +336,10 @@ where self.buffered_transmits.extend(c.handle_timeout(now)); } + for binding in self.stun_servers.values_mut() { + binding.handle_timeout(now); + } + let next_reset = *self.next_rate_limiter_reset.get_or_insert(now); if now >= next_reset { @@ -335,6 +359,12 @@ where } } + for binding in self.stun_servers.values_mut() { + if let Some(transmit) = binding.poll_transmit() { + return Some(transmit); + } + } + for conn in self.negotiated_connections.values_mut() { if let Some(transmit) = conn.agent.poll_transmit() { return Some(Transmit { @@ -356,25 +386,23 @@ where maybe_initial_connection.or(maybe_established_connection) } -} -impl ConnectionPool -where - TId: Eq + Hash + Copy, -{ - /// Create a new connection indexed by the given ID. - /// - /// Out of all configured STUN and TURN servers, the connection will only use the ones provided here. - /// The returned [`Offer`] must be passed to the remote via a signalling channel. - pub fn new_connection( + fn upsert_stun_servers(&mut self, servers: &HashSet) { + for server in servers { + if !self.stun_servers.contains_key(server) { + tracing::debug!(address = %server, "Adding new STUN server"); + + self.stun_servers.insert(*server, StunBinding::new(*server)); + } + } + } + + fn seed_agent_with_local_candidates( &mut self, - id: TId, - allowed_stun_servers: Vec, - allowed_turn_servers: Vec, - ) -> Offer { - let mut agent = IceAgent::new(); - agent.set_controlling(true); - + connection: TId, + agent: &mut IceAgent, + allowed_stun_servers: &HashSet, + ) { for local in self.local_interfaces.iter().copied() { let candidate = match Candidate::host(local, Protocol::Udp) { Ok(c) => c, @@ -384,14 +412,52 @@ where } }; - if agent.add_local_candidate(candidate.clone()) { - self.pending_events.push_back(Event::SignalIceCandidate { - connection: id, - candidate: candidate.to_sdp_string(), - }); - } + add_local_candidate( + connection, + agent, + candidate.clone(), + &mut self.pending_events, + ); } + for candidate in self.stun_servers.iter().filter_map(|(server, binding)| { + let candidate = allowed_stun_servers + .contains(server) + .then(|| binding.candidate())??; + + Some(candidate) + }) { + add_local_candidate( + connection, + agent, + candidate.clone(), + &mut self.pending_events, + ); + } + } +} + +impl ConnectionPool +where + TId: Eq + Hash + Copy + fmt::Display, +{ + /// Create a new connection indexed by the given ID. + /// + /// Out of all configured STUN and TURN servers, the connection will only use the ones provided here. + /// The returned [`Offer`] must be passed to the remote via a signalling channel. + pub fn new_connection( + &mut self, + id: TId, + allowed_stun_servers: HashSet, + allowed_turn_servers: HashSet, + ) -> Offer { + self.upsert_stun_servers(&allowed_stun_servers); + + let mut agent = IceAgent::new(); + agent.set_controlling(true); + + self.seed_agent_with_local_candidates(id, &mut agent, &allowed_stun_servers); + let session_key = Secret::new(random()); let ice_creds = agent.local_credentials(); @@ -440,7 +506,7 @@ where self.index.next(), Some(self.rate_limiter.clone()), ), - _stun_servers: initial.stun_servers, + stun_servers: initial.stun_servers, _turn_servers: initial.turn_servers, next_timer_update: None, remote_socket: None, @@ -452,16 +518,18 @@ where impl ConnectionPool where - TId: Eq + Hash + Copy, + TId: Eq + Hash + Copy + fmt::Display, { pub fn accept_connection( &mut self, id: TId, offer: Offer, remote: PublicKey, - allowed_stun_servers: Vec, - allowed_turn_servers: Vec, + allowed_stun_servers: HashSet, + allowed_turn_servers: HashSet, ) -> Answer { + self.upsert_stun_servers(&allowed_stun_servers); + let mut agent = IceAgent::new(); agent.set_controlling(false); agent.set_remote_credentials(IceCreds { @@ -475,22 +543,7 @@ where }, }; - for local in self.local_interfaces.iter().copied() { - let candidate = match Candidate::host(local, Protocol::Udp) { - Ok(c) => c, - Err(e) => { - tracing::debug!("Failed to generate host candidate from addr: {e}"); - continue; - } - }; - - if agent.add_local_candidate(candidate.clone()) { - self.pending_events.push_back(Event::SignalIceCandidate { - connection: id, - candidate: candidate.to_sdp_string(), - }); - } - } + self.seed_agent_with_local_candidates(id, &mut agent, &allowed_stun_servers); self.negotiated_connections.insert( id, @@ -504,7 +557,7 @@ where self.index.next(), Some(self.rate_limiter.clone()), ), - _stun_servers: allowed_stun_servers, + stun_servers: allowed_stun_servers, _turn_servers: allowed_turn_servers, next_timer_update: None, remote_socket: None, @@ -516,6 +569,55 @@ where } } +fn drain_binding_events( + server: SocketAddr, + binding: &mut StunBinding, + initial_connections: &mut HashMap, + negotiated_connections: &mut HashMap, + pending_events: &mut VecDeque>, +) where + TId: Copy, +{ + while let Some(event) = binding.poll_event() { + match event { + crate::stun_binding::Event::NewServerReflexiveCandidate { candidate } => { + // TODO: Reduce duplication between initial and negotiated connections + for (id, c) in initial_connections.iter_mut() { + if !c.stun_servers.contains(&server) { + continue; + } + + add_local_candidate(*id, &mut c.agent, candidate.clone(), pending_events); + } + + for (id, c) in negotiated_connections.iter_mut() { + if !c.stun_servers.contains(&server) { + continue; + } + + add_local_candidate(*id, &mut c.agent, candidate.clone(), pending_events); + } + } + }; + } +} + +fn add_local_candidate( + id: TId, + agent: &mut IceAgent, + candidate: Candidate, + pending_events: &mut VecDeque>, +) { + let is_new = agent.add_local_candidate(candidate.clone()); + + if is_new { + pending_events.push_back(Event::SignalIceCandidate { + connection: id, + candidate: candidate.to_sdp_string(), + }) + } +} + pub struct Offer { /// The Wireguard session key for a connection. pub session_key: Secret<[u8; 32]>, @@ -542,8 +644,14 @@ pub enum Event { candidate: String, }, ConnectionEstablished(TId), + + /// We tested all candidates and failed to establish a connection. + /// + /// This condition will not resolve unless more candidates are added or the network conditions change. + ConnectionFailed(TId), } +#[derive(Debug)] pub struct Transmit { pub dst: SocketAddr, pub payload: Vec, @@ -552,8 +660,8 @@ pub struct Transmit { pub struct InitialConnection { agent: IceAgent, session_key: Secret<[u8; 32]>, - stun_servers: Vec, - turn_servers: Vec, + stun_servers: HashSet, + turn_servers: HashSet, } struct Connection { @@ -567,8 +675,8 @@ struct Connection { // Socket addresses from which we might receive data (even before we are connected). possible_sockets: HashSet, - _stun_servers: Vec, - _turn_servers: Vec, + stun_servers: HashSet, + _turn_servers: HashSet, } impl Connection { diff --git a/rust/connlib/connection/src/stun_binding.rs b/rust/connlib/connection/src/stun_binding.rs new file mode 100644 index 000000000..7b8dc04f3 --- /dev/null +++ b/rust/connlib/connection/src/stun_binding.rs @@ -0,0 +1,321 @@ +use crate::pool::Transmit; +use bytecodec::{DecodeExt, EncodeExt}; +use std::{ + collections::VecDeque, + net::SocketAddr, + time::{Duration, Instant}, +}; +use str0m::{net::Protocol, Candidate}; +use stun_codec::{ + rfc5389::{self, attributes::XorMappedAddress}, + Attribute, Message, TransactionId, +}; + +const STUN_TIMEOUT: Duration = Duration::from_secs(5); +const STUN_REFRESH: Duration = Duration::from_secs(5 * 60); + +/// A SANS-IO state machine that obtains a server-reflexive candidate from the configured STUN server. +#[derive(Debug)] +pub struct StunBinding { + server: SocketAddr, + last_candidate: Option, + state: State, + last_now: Option, + + buffered_transmits: VecDeque, + buffered_events: VecDeque, +} + +impl StunBinding { + pub fn new(server: SocketAddr) -> Self { + Self { + server, + last_candidate: None, + state: State::Initial, + last_now: None, + buffered_transmits: Default::default(), + buffered_events: Default::default(), + } + } + + pub fn candidate(&self) -> Option { + self.last_candidate.clone() + } + + pub fn handle_input(&mut self, from: SocketAddr, packet: &[u8], now: Instant) -> bool { + self.last_now = Some(now); // TODO: Do we need to do any other updates here? + + if from != self.server { + return false; + } + + let Ok(Ok(message)) = + stun_codec::MessageDecoder::::default() + .decode_from_bytes(packet) + else { + return false; + }; + + match self.state { + State::SentRequest { id, .. } if id == message.transaction_id() => { + self.state = State::ReceivedResponse { at: now } + } + _ => { + return false; + } + } + + let Some(mapped_address) = message.get_attribute::() else { + tracing::warn!("STUN server replied but is missing `XOR-MAPPED-ADDRESS"); + return true; + }; + + let observed_address = mapped_address.address(); + + let new_candidate = match Candidate::server_reflexive(observed_address, Protocol::Udp) { + Ok(c) => c, + Err(e) => { + tracing::debug!("Observed address is not a valid candidate: {e}"); + return true; // We still handled the packet correctly. + } + }; + + match &self.last_candidate { + Some(candidate) if candidate != &new_candidate => { + tracing::info!(current = %candidate, new = %new_candidate, "Updating server-reflexive candidate"); + } + None => { + tracing::info!(new = %new_candidate, "New server-reflexive candidate"); + } + _ => return true, + } + + self.last_candidate = Some(new_candidate.clone()); + self.buffered_events + .push_back(Event::NewServerReflexiveCandidate { + candidate: new_candidate, + }); + + true + } + + pub fn poll_event(&mut self) -> Option { + self.buffered_events.pop_front() + } + + pub fn poll_timeout(&mut self) -> Option { + match self.state { + State::Initial => None, + State::SentRequest { at, .. } => Some(at + STUN_TIMEOUT), + State::ReceivedResponse { at } => Some(at + STUN_REFRESH), + } + } + + pub fn handle_timeout(&mut self, now: Instant) { + self.last_now = Some(now); + + match self.state { + State::Initial => { + tracing::debug!(server = %self.server, "Sending new STUN request"); + } + State::SentRequest { id, at } if at + STUN_TIMEOUT <= now => { + tracing::debug!(?id, "STUN request timed out, sending new one"); + } + State::ReceivedResponse { at } if at + STUN_REFRESH <= now => { + tracing::debug!("Refreshing STUN binding"); + } + _ => return, + } + + let request = new_stun_request(); + + self.state = State::SentRequest { + id: request.transaction_id(), + at: now, + }; + + self.buffered_transmits.push_back(Transmit { + dst: self.server, + payload: encode(request), + }); + } + + pub fn poll_transmit(&mut self) -> Option { + self.buffered_transmits.pop_front() + } + + #[cfg(test)] + fn set_received_at(&mut self, address: SocketAddr, now: Instant) { + self.last_now = Some(now); + self.last_candidate = Some(Candidate::server_reflexive(address, Protocol::Udp).unwrap()); + self.state = State::ReceivedResponse { at: now }; + } +} + +#[derive(Debug)] +pub enum Event { + NewServerReflexiveCandidate { candidate: Candidate }, +} + +fn new_stun_request() -> Message { + Message::new( + stun_codec::MessageClass::Request, + rfc5389::methods::BINDING, + TransactionId::new(rand::random()), + ) +} + +fn encode(message: Message) -> Vec +where + A: Attribute, +{ + stun_codec::MessageEncoder::::default() + .encode_into_bytes(message) + .unwrap() +} + +#[derive(Debug)] +enum State { + Initial, + SentRequest { id: TransactionId, at: Instant }, + ReceivedResponse { at: Instant }, +} + +#[cfg(test)] +mod tests { + use super::*; + use bytecodec::DecodeExt; + use std::{ + net::{Ipv4Addr, SocketAddrV4}, + time::Duration, + }; + use stun_codec::{ + rfc5389::{attributes::XorMappedAddress, methods::BINDING}, + Message, + }; + + const SERVER1: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3478)); + const SERVER2: SocketAddr = + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 3478)); + const MAPPED_ADDRESS: SocketAddr = + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 9999)); + + #[test] + fn initial_binding_sends_request() { + let mut stun_binding = StunBinding::new(SERVER1); + + stun_binding.handle_timeout(Instant::now()); + let transmit = stun_binding.poll_transmit().unwrap(); + + assert_eq!(transmit.dst, SERVER1); + } + + #[test] + fn repeated_polling_does_not_generate_more_requests() { + let mut stun_binding = StunBinding::new(SERVER1); + + stun_binding.handle_timeout(Instant::now()); + + assert!(stun_binding.poll_transmit().is_some()); + assert!(stun_binding.poll_transmit().is_none()); + } + + #[test] + fn request_times_out_after_5_seconds_and_generates_new_request() { + let mut stun_binding = StunBinding::new(SERVER1); + + let start = Instant::now(); + stun_binding.handle_timeout(start); + + assert!(stun_binding.poll_transmit().is_some()); + assert!(stun_binding.poll_transmit().is_none()); + + assert_eq!( + stun_binding.poll_timeout().unwrap(), + start + Duration::from_secs(5) + ); + + // Nothing after 1 second .. + stun_binding.handle_timeout(start + Duration::from_secs(1)); + assert!(stun_binding.poll_transmit().is_none()); + + // Nothing after 2 seconds .. + stun_binding.handle_timeout(start + Duration::from_secs(2)); + assert!(stun_binding.poll_transmit().is_none()); + + stun_binding.handle_timeout(start + Duration::from_secs(5)); + assert!(stun_binding.poll_transmit().is_some()); + assert!(stun_binding.poll_transmit().is_none()); + } + + #[test] + fn mapped_address_is_emitted_as_event() { + let mut stun_binding = StunBinding::new(SERVER1); + let start = Instant::now(); + + stun_binding.handle_timeout(start); + + let request = stun_binding.poll_transmit().unwrap(); + let response = generate_stun_response(request, MAPPED_ADDRESS); + + let handled = + stun_binding.handle_input(SERVER1, &response, start + Duration::from_millis(200)); + assert!(handled); + + let Event::NewServerReflexiveCandidate { candidate } = stun_binding.poll_event().unwrap(); + + assert_eq!(candidate.addr(), MAPPED_ADDRESS); + } + + #[test] + fn stun_binding_is_refreshed_every_five_minutes() { + let start = Instant::now(); + + let mut stun_binding = StunBinding::new(SERVER1); + stun_binding.set_received_at(MAPPED_ADDRESS, start); + assert!(stun_binding.poll_transmit().is_none()); + + stun_binding.handle_timeout(start + Duration::from_secs(5 * 60)); + + assert!(stun_binding.poll_transmit().is_some()); + } + + #[test] + fn response_from_other_server_is_discarded() { + let mut stun_binding = StunBinding::new(SERVER1); + let start = Instant::now(); + + stun_binding.handle_timeout(start); + + let request = stun_binding.poll_transmit().unwrap(); + let response = generate_stun_response(request, MAPPED_ADDRESS); + + let handled = + stun_binding.handle_input(SERVER2, &response, start + Duration::from_millis(200)); + + assert!(!handled); + assert!(stun_binding.poll_event().is_none()); + } + + fn generate_stun_response(request: Transmit, mapped_address: SocketAddr) -> Vec { + let mut decoder = stun_codec::MessageDecoder::::default(); + + let message = decoder + .decode_from_bytes(&request.payload) + .unwrap() + .unwrap(); + + let transaction_id = message.transaction_id(); + + let mut response = Message::::new( + stun_codec::MessageClass::SuccessResponse, + BINDING, + transaction_id, + ); + response.add_attribute(stun_codec::rfc5389::Attribute::XorMappedAddress( + XorMappedAddress::new(mapped_address), + )); + + encode(response) + } +}