feat(connection): use STUN to generate server-reflexive candidate (#3268)

Currently, `firezone-connection` can only handle connections on a LAN.
Via the use of a STUN server, we can discover our public IP and attempt
to direct, hole-punched connection across multiple subnets.
This commit is contained in:
Thomas Eizinger
2024-01-18 20:11:07 -08:00
committed by GitHub
parent 613ca00b1c
commit 66c85e28b0
13 changed files with 737 additions and 61 deletions

View File

@@ -151,7 +151,10 @@ jobs:
strategy:
fail-fast: false
matrix:
file: ['docker-compose.lan.yml']
file: [
'docker-compose.lan.yml',
'docker-compose.wan.yml'
]
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/gcp-docker-login
@@ -161,7 +164,7 @@ jobs:
- name: Run ${{ matrix.file }} test
run: |
sudo sysctl -w vm.overcommit_memory=1
docker compose -f rust/connection-tests/${{ matrix.file }} up --exit-code-from dialer --abort-on-container-exit
timeout 600 docker compose -f rust/connection-tests/${{ matrix.file }} up --exit-code-from dialer --abort-on-container-exit
integration-tests:
needs: build-images

View File

@@ -157,3 +157,27 @@ Prebuilt Binaries License
details, features, specifications, capabilities, functions, licensing
terms, release dates, APIs, ABIs, general availability, or other
characteristics of the Software.
===
Portions of this product (`rust/connection-tests`) contain derivative work from https://github.com/libp2p/test-plans. Its license is included below:
The MIT License (MIT)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

2
rust/Cargo.lock generated
View File

@@ -1975,11 +1975,13 @@ version = "1.0.0"
dependencies = [
"anyhow",
"boringtun",
"bytecodec",
"firezone-relay",
"pnet_packet",
"rand 0.8.5",
"secrecy",
"str0m",
"stun_codec",
"thiserror",
"tracing",
]

View File

@@ -124,7 +124,7 @@ CMD $PACKAGE
FROM runtime AS debug
RUN set -xe \
&& apk add --no-cache iperf3 jq
&& apk add --no-cache iperf3 bind-tools iproute2 jq
ARG TARGET
COPY --from=builder /build/target/${TARGET}/debug/${PACKAGE} .

View File

@@ -0,0 +1,147 @@
version: "3.8"
name: wan-hp-integration-test
services:
dialer:
build:
target: debug
context: ..
args:
PACKAGE: firezone-connection-tests
cache_from:
- type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/connection-tests:main
image: us-east1-docker.pkg.dev/firezone-staging/firezone/connection-tests:${VERSION:-main}
environment:
ROLE: "dialer"
cap_add:
- NET_ADMIN
entrypoint: /bin/sh
command:
- -c
- |
set -ex
ROUTER_IP=$$(dig +short dialer_router)
INTERNET_SUBNET=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/networks/wan-hp-integration-test_wan | jq -r '.IPAM.Config[0].Subnet')
ip route add $$INTERNET_SUBNET via $$ROUTER_IP dev eth0
export STUN_SERVER=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/containers/wan-hp-integration-test-relay-1/json | jq -r '.NetworkSettings.Networks."wan-hp-integration-test_wan".IPAddress')
export REDIS_HOST=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/containers/wan-hp-integration-test-redis-1/json | jq -r '.NetworkSettings.Networks."wan-hp-integration-test_wan".IPAddress')
firezone-connection-tests
depends_on:
- dialer_router
- redis
networks:
- lan1
volumes:
- /var/run/docker.sock:/var/run/docker.sock
dialer_router:
init: true
build:
context: ./router
cap_add:
- NET_ADMIN
networks:
- lan1
- wan
listener:
build:
target: debug
context: ..
args:
PACKAGE: firezone-connection-tests
cache_from:
- type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/connection-tests:main
image: us-east1-docker.pkg.dev/firezone-staging/firezone/connection-tests:${VERSION:-main}
init: true
environment:
ROLE: "listener"
entrypoint: /bin/sh
command:
- -c
- |
set -ex
ROUTER_IP=$$(dig +short listener_router)
INTERNET_SUBNET=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/networks/wan-hp-integration-test_wan | jq -r '.IPAM.Config[0].Subnet')
ip route add $$INTERNET_SUBNET via $$ROUTER_IP dev eth0
export STUN_SERVER=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/containers/wan-hp-integration-test-relay-1/json | jq -r '.NetworkSettings.Networks."wan-hp-integration-test_wan".IPAddress')
export REDIS_HOST=$$(curl --fail --silent --unix-socket /var/run/docker.sock http://localhost/containers/wan-hp-integration-test-redis-1/json | jq -r '.NetworkSettings.Networks."wan-hp-integration-test_wan".IPAddress')
firezone-connection-tests
cap_add:
- NET_ADMIN
depends_on:
- listener_router
- redis
networks:
- lan2
volumes:
- /var/run/docker.sock:/var/run/docker.sock
listener_router:
init: true
build:
context: ./router
cap_add:
- NET_ADMIN
networks:
- lan2
- wan
relay:
environment:
LOWEST_PORT: 55555
HIGHEST_PORT: 55666
RUST_LOG: "debug"
RUST_BACKTRACE: 1
build:
target: debug
context: ..
cache_from:
- type=registry,ref=us-east1-docker.pkg.dev/firezone-staging/cache/relay:main
args:
PACKAGE: firezone-relay
image: us-east1-docker.pkg.dev/firezone-staging/firezone/relay:${VERSION:-main}
init: true
healthcheck:
test: ["CMD-SHELL", "lsof -i UDP | grep firezone-relay"]
start_period: 20s
interval: 30s
retries: 5
timeout: 5s
entrypoint: /bin/sh
command:
- -c
- |
set -ex;
export PUBLIC_IP4_ADDR=$(ip -json addr show eth0 | jq '.[0].addr_info[0].local' -r)
firezone-relay
ports:
# XXX: Only 111 ports are used for local dev / testing because Docker Desktop
# allocates a userland proxy process for each forwarded port X_X.
#
# Large ranges here will bring your machine to its knees.
- "55555-55666:55555-55666/udp"
- 3478:3478/udp
networks:
- wan
redis:
image: "redis:7-alpine"
healthcheck:
test: ["CMD-SHELL", "echo 'ready';"]
networks:
- wan
networks:
lan1:
lan2:
wan:

View File

@@ -0,0 +1,11 @@
FROM debian:12-slim
ARG DEBIAN_FRONTEND=noninteractive
RUN --mount=type=cache,target=/var/cache/apt apt-get update && apt-get -y install iproute2 nftables conntrack
COPY *.sh /scripts/
RUN chmod +x /scripts/*.sh
HEALTHCHECK CMD [ "sh", "-c", "test $(cat /tmp/setup_done) = 1" ]
ENTRYPOINT ["./scripts/run.sh"]

View File

@@ -0,0 +1,18 @@
# Router
This directory contains a Debian-based router implemented on top of nftables.
It expects to be run with two network interfaces:
- `eth1`: The "external" interface.
- `eth0`: The "internal" interface.
The order of these interfaces depends on lexical sorting the docker networks names.
The order of these is important.
The router cannot possibly know which one is which and thus assumes that `eth0` is the external one and `eth1` the internal one.
The firewall is set up to take incoming traffic on `eth1` and forward + masquerade it to `eth0`.
It also expects an env variable `DELAY_MS` to be set and will apply this delay as part of the routing process[^1].
[^1]: This is done via `tc qdisc` which only works for egress traffic. To ensure the delay applies in both directions, we divide it by 2 and apply it on both interfaces.

View File

@@ -0,0 +1,18 @@
#!/bin/sh
set -ex
# Set up NAT
nft add table ip nat
nft add chain ip nat postrouting { type nat hook postrouting priority 100 \; }
nft add rule ip nat postrouting masquerade
# Assumption after a long debugging session involving Gabi, Jamil and Thomas:
# On the same machine, the kernel cannot differentiate between incoming and outgoing packets across different network namespaces within the firewall and NAT mapping table.
# As a result, even UDP hole-punching is time-sensitive and we thus need to make sure that we first send a packet _out_ through the router before the other one is incoming.
# To achieve this, we set an absurdly high latency of 300ms for the WAN network.
tc qdisc add dev eth1 root netem delay 300ms
echo "1" > /tmp/setup_done # This will be checked by our docker HEALTHCHECK
conntrack --event --proto UDP --output timestamp # Display a real-time log of NAT events in the kernel.

View File

@@ -1,6 +1,7 @@
use std::{
collections::HashSet,
future::poll_fn,
net::Ipv4Addr,
net::{IpAddr, Ipv4Addr, SocketAddr},
str::FromStr,
task::{Context, Poll},
time::Instant,
@@ -42,6 +43,13 @@ async fn main() -> Result<()> {
.ip
.to_std();
let stun_server = std::env::var("STUN_SERVER")
.ok()
.map(|a| a.parse::<IpAddr>())
.transpose()
.context("Failed to parse `STUN_SERVER`")?
.map(|ip| SocketAddr::new(ip, 3478));
tracing::info!(%listen_addr);
let redis_host = std::env::var("REDIS_HOST").context("Missing REDIS_HOST env var")?;
@@ -63,7 +71,8 @@ async fn main() -> Result<()> {
let mut pool = ClientConnectionPool::<u64>::new(private_key);
pool.add_local_interface(socket_addr);
let offer = pool.new_connection(1, vec![], vec![]);
let offer =
pool.new_connection(1, stun_server.into_iter().collect(), HashSet::default());
redis_connection
.rpush(
@@ -124,6 +133,9 @@ async fn main() -> Result<()> {
start = Instant::now();
eventloop.send_to(conn, ip4_udp_ping_packet(source, dst, &ping_body).into())?;
}
Event::ConnectionFailed { conn } => {
anyhow::bail!("Failed to establish connection: {conn}");
}
}
}
@@ -156,8 +168,8 @@ async fn main() -> Result<()> {
},
},
offer.public_key.into(),
vec![],
vec![],
stun_server.into_iter().collect(),
HashSet::default(),
);
redis_connection
@@ -188,6 +200,9 @@ async fn main() -> Result<()> {
.context("Failed to push candidate")?;
}
Event::ConnectionEstablished { .. } => { }
Event::ConnectionFailed { conn } => {
anyhow::bail!("Failed to establish connection: {conn}");
}
}
}
@@ -340,6 +355,9 @@ impl<T> Eventloop<T> {
Some(firezone_connection::Event::ConnectionEstablished(conn)) => {
return Poll::Ready(Ok(Event::ConnectionEstablished { conn }))
}
Some(firezone_connection::Event::ConnectionFailed(conn)) => {
return Poll::Ready(Ok(Event::ConnectionFailed { conn }))
}
None => {}
}
@@ -389,6 +407,9 @@ enum Event {
ConnectionEstablished {
conn: u64,
},
ConnectionFailed {
conn: u64,
},
}
async fn sleep_until(deadline: Instant) -> Instant {

View File

@@ -14,3 +14,5 @@ secrecy = { workspace = true }
str0m = { workspace = true }
thiserror = "1"
tracing = "0.1"
stun_codec = "0.3.4"
bytecodec = "0.4.15"

View File

@@ -1,6 +1,7 @@
mod index;
mod ip_packet;
mod pool;
mod stun_binding;
pub use ip_packet::IpPacket;
pub use pool::{

View File

@@ -17,9 +17,10 @@ use std::{
};
use str0m::ice::{IceAgent, IceAgentEvent, IceCreds};
use str0m::net::{Protocol, Receive};
use str0m::{Candidate, StunMessage};
use str0m::{Candidate, IceConnectionState, StunMessage};
use crate::index::IndexLfsr;
use crate::stun_binding::StunBinding;
use crate::IpPacket;
// Note: Taken from boringtun
@@ -44,6 +45,8 @@ pub struct ConnectionPool<T, TId> {
next_rate_limiter_reset: Option<Instant>,
stun_servers: HashMap<SocketAddr, StunBinding>,
initial_connections: HashMap<TId, InitialConnection>,
negotiated_connections: HashMap<TId, Connection>,
pending_events: VecDeque<Event<TId>>,
@@ -85,6 +88,7 @@ where
pending_events: VecDeque::default(),
initial_connections: HashMap::default(),
buffer: Box::new([0u8; MAX_UDP_SIZE]),
stun_servers: HashMap::default(),
}
}
@@ -127,9 +131,20 @@ where
return Err(Error::UnknownInterface);
}
// TODO: First thing we need to check if the message is from one of our STUN / TURN servers AND it is a STUN message (starts with 0x03)
// ...
// ...
// First, check if a `StunBinding` wants the packet
if let Some(binding) = self.stun_servers.get_mut(&from) {
if binding.handle_input(from, packet, now) {
// If it handled the packet, drain its events to ensure we update the candidates of all connections.
drain_binding_events(
from,
binding,
&mut self.initial_connections,
&mut self.negotiated_connections,
&mut self.pending_events,
);
return Ok(None);
}
}
// Next: If we can parse the message as a STUN message, cycle through all agents to check which one it is for.
if let Ok(stun_message) = StunMessage::parse(packet) {
@@ -270,8 +285,9 @@ where
conn.possible_sockets.insert(source);
// TODO: Here is where we'd allocate channels.
}
IceAgentEvent::IceRestart(_) => {}
IceAgentEvent::IceConnectionStateChange(_) => {}
IceAgentEvent::IceConnectionStateChange(IceConnectionState::Disconnected) => {
return Some(Event::ConnectionFailed(*id));
}
IceAgentEvent::NominatedSend { destination, .. } => match conn.remote_socket {
Some(old) if old != destination => {
tracing::info!(%id, new = %destination, %old, "Migrating connection to peer");
@@ -285,6 +301,7 @@ where
}
_ => {}
},
_ => {}
}
}
}
@@ -304,6 +321,9 @@ where
for c in self.negotiated_connections.values_mut() {
connection_timeout = earliest(connection_timeout, c.poll_timeout());
}
for b in self.stun_servers.values_mut() {
connection_timeout = earliest(connection_timeout, b.poll_timeout());
}
earliest(connection_timeout, self.next_rate_limiter_reset)
}
@@ -316,6 +336,10 @@ where
self.buffered_transmits.extend(c.handle_timeout(now));
}
for binding in self.stun_servers.values_mut() {
binding.handle_timeout(now);
}
let next_reset = *self.next_rate_limiter_reset.get_or_insert(now);
if now >= next_reset {
@@ -335,6 +359,12 @@ where
}
}
for binding in self.stun_servers.values_mut() {
if let Some(transmit) = binding.poll_transmit() {
return Some(transmit);
}
}
for conn in self.negotiated_connections.values_mut() {
if let Some(transmit) = conn.agent.poll_transmit() {
return Some(Transmit {
@@ -356,25 +386,23 @@ where
maybe_initial_connection.or(maybe_established_connection)
}
}
impl<TId> ConnectionPool<Client, TId>
where
TId: Eq + Hash + Copy,
{
/// Create a new connection indexed by the given ID.
///
/// Out of all configured STUN and TURN servers, the connection will only use the ones provided here.
/// The returned [`Offer`] must be passed to the remote via a signalling channel.
pub fn new_connection(
fn upsert_stun_servers(&mut self, servers: &HashSet<SocketAddr>) {
for server in servers {
if !self.stun_servers.contains_key(server) {
tracing::debug!(address = %server, "Adding new STUN server");
self.stun_servers.insert(*server, StunBinding::new(*server));
}
}
}
fn seed_agent_with_local_candidates(
&mut self,
id: TId,
allowed_stun_servers: Vec<SocketAddr>,
allowed_turn_servers: Vec<SocketAddr>,
) -> Offer {
let mut agent = IceAgent::new();
agent.set_controlling(true);
connection: TId,
agent: &mut IceAgent,
allowed_stun_servers: &HashSet<SocketAddr>,
) {
for local in self.local_interfaces.iter().copied() {
let candidate = match Candidate::host(local, Protocol::Udp) {
Ok(c) => c,
@@ -384,14 +412,52 @@ where
}
};
if agent.add_local_candidate(candidate.clone()) {
self.pending_events.push_back(Event::SignalIceCandidate {
connection: id,
candidate: candidate.to_sdp_string(),
});
}
add_local_candidate(
connection,
agent,
candidate.clone(),
&mut self.pending_events,
);
}
for candidate in self.stun_servers.iter().filter_map(|(server, binding)| {
let candidate = allowed_stun_servers
.contains(server)
.then(|| binding.candidate())??;
Some(candidate)
}) {
add_local_candidate(
connection,
agent,
candidate.clone(),
&mut self.pending_events,
);
}
}
}
impl<TId> ConnectionPool<Client, TId>
where
TId: Eq + Hash + Copy + fmt::Display,
{
/// Create a new connection indexed by the given ID.
///
/// Out of all configured STUN and TURN servers, the connection will only use the ones provided here.
/// The returned [`Offer`] must be passed to the remote via a signalling channel.
pub fn new_connection(
&mut self,
id: TId,
allowed_stun_servers: HashSet<SocketAddr>,
allowed_turn_servers: HashSet<SocketAddr>,
) -> Offer {
self.upsert_stun_servers(&allowed_stun_servers);
let mut agent = IceAgent::new();
agent.set_controlling(true);
self.seed_agent_with_local_candidates(id, &mut agent, &allowed_stun_servers);
let session_key = Secret::new(random());
let ice_creds = agent.local_credentials();
@@ -440,7 +506,7 @@ where
self.index.next(),
Some(self.rate_limiter.clone()),
),
_stun_servers: initial.stun_servers,
stun_servers: initial.stun_servers,
_turn_servers: initial.turn_servers,
next_timer_update: None,
remote_socket: None,
@@ -452,16 +518,18 @@ where
impl<TId> ConnectionPool<Server, TId>
where
TId: Eq + Hash + Copy,
TId: Eq + Hash + Copy + fmt::Display,
{
pub fn accept_connection(
&mut self,
id: TId,
offer: Offer,
remote: PublicKey,
allowed_stun_servers: Vec<SocketAddr>,
allowed_turn_servers: Vec<SocketAddr>,
allowed_stun_servers: HashSet<SocketAddr>,
allowed_turn_servers: HashSet<SocketAddr>,
) -> Answer {
self.upsert_stun_servers(&allowed_stun_servers);
let mut agent = IceAgent::new();
agent.set_controlling(false);
agent.set_remote_credentials(IceCreds {
@@ -475,22 +543,7 @@ where
},
};
for local in self.local_interfaces.iter().copied() {
let candidate = match Candidate::host(local, Protocol::Udp) {
Ok(c) => c,
Err(e) => {
tracing::debug!("Failed to generate host candidate from addr: {e}");
continue;
}
};
if agent.add_local_candidate(candidate.clone()) {
self.pending_events.push_back(Event::SignalIceCandidate {
connection: id,
candidate: candidate.to_sdp_string(),
});
}
}
self.seed_agent_with_local_candidates(id, &mut agent, &allowed_stun_servers);
self.negotiated_connections.insert(
id,
@@ -504,7 +557,7 @@ where
self.index.next(),
Some(self.rate_limiter.clone()),
),
_stun_servers: allowed_stun_servers,
stun_servers: allowed_stun_servers,
_turn_servers: allowed_turn_servers,
next_timer_update: None,
remote_socket: None,
@@ -516,6 +569,55 @@ where
}
}
fn drain_binding_events<TId>(
server: SocketAddr,
binding: &mut StunBinding,
initial_connections: &mut HashMap<TId, InitialConnection>,
negotiated_connections: &mut HashMap<TId, Connection>,
pending_events: &mut VecDeque<Event<TId>>,
) where
TId: Copy,
{
while let Some(event) = binding.poll_event() {
match event {
crate::stun_binding::Event::NewServerReflexiveCandidate { candidate } => {
// TODO: Reduce duplication between initial and negotiated connections
for (id, c) in initial_connections.iter_mut() {
if !c.stun_servers.contains(&server) {
continue;
}
add_local_candidate(*id, &mut c.agent, candidate.clone(), pending_events);
}
for (id, c) in negotiated_connections.iter_mut() {
if !c.stun_servers.contains(&server) {
continue;
}
add_local_candidate(*id, &mut c.agent, candidate.clone(), pending_events);
}
}
};
}
}
fn add_local_candidate<TId>(
id: TId,
agent: &mut IceAgent,
candidate: Candidate,
pending_events: &mut VecDeque<Event<TId>>,
) {
let is_new = agent.add_local_candidate(candidate.clone());
if is_new {
pending_events.push_back(Event::SignalIceCandidate {
connection: id,
candidate: candidate.to_sdp_string(),
})
}
}
pub struct Offer {
/// The Wireguard session key for a connection.
pub session_key: Secret<[u8; 32]>,
@@ -542,8 +644,14 @@ pub enum Event<TId> {
candidate: String,
},
ConnectionEstablished(TId),
/// We tested all candidates and failed to establish a connection.
///
/// This condition will not resolve unless more candidates are added or the network conditions change.
ConnectionFailed(TId),
}
#[derive(Debug)]
pub struct Transmit {
pub dst: SocketAddr,
pub payload: Vec<u8>,
@@ -552,8 +660,8 @@ pub struct Transmit {
pub struct InitialConnection {
agent: IceAgent,
session_key: Secret<[u8; 32]>,
stun_servers: Vec<SocketAddr>,
turn_servers: Vec<SocketAddr>,
stun_servers: HashSet<SocketAddr>,
turn_servers: HashSet<SocketAddr>,
}
struct Connection {
@@ -567,8 +675,8 @@ struct Connection {
// Socket addresses from which we might receive data (even before we are connected).
possible_sockets: HashSet<SocketAddr>,
_stun_servers: Vec<SocketAddr>,
_turn_servers: Vec<SocketAddr>,
stun_servers: HashSet<SocketAddr>,
_turn_servers: HashSet<SocketAddr>,
}
impl Connection {

View File

@@ -0,0 +1,321 @@
use crate::pool::Transmit;
use bytecodec::{DecodeExt, EncodeExt};
use std::{
collections::VecDeque,
net::SocketAddr,
time::{Duration, Instant},
};
use str0m::{net::Protocol, Candidate};
use stun_codec::{
rfc5389::{self, attributes::XorMappedAddress},
Attribute, Message, TransactionId,
};
const STUN_TIMEOUT: Duration = Duration::from_secs(5);
const STUN_REFRESH: Duration = Duration::from_secs(5 * 60);
/// A SANS-IO state machine that obtains a server-reflexive candidate from the configured STUN server.
#[derive(Debug)]
pub struct StunBinding {
server: SocketAddr,
last_candidate: Option<Candidate>,
state: State,
last_now: Option<Instant>,
buffered_transmits: VecDeque<Transmit>,
buffered_events: VecDeque<Event>,
}
impl StunBinding {
pub fn new(server: SocketAddr) -> Self {
Self {
server,
last_candidate: None,
state: State::Initial,
last_now: None,
buffered_transmits: Default::default(),
buffered_events: Default::default(),
}
}
pub fn candidate(&self) -> Option<Candidate> {
self.last_candidate.clone()
}
pub fn handle_input(&mut self, from: SocketAddr, packet: &[u8], now: Instant) -> bool {
self.last_now = Some(now); // TODO: Do we need to do any other updates here?
if from != self.server {
return false;
}
let Ok(Ok(message)) =
stun_codec::MessageDecoder::<stun_codec::rfc5389::Attribute>::default()
.decode_from_bytes(packet)
else {
return false;
};
match self.state {
State::SentRequest { id, .. } if id == message.transaction_id() => {
self.state = State::ReceivedResponse { at: now }
}
_ => {
return false;
}
}
let Some(mapped_address) = message.get_attribute::<XorMappedAddress>() else {
tracing::warn!("STUN server replied but is missing `XOR-MAPPED-ADDRESS");
return true;
};
let observed_address = mapped_address.address();
let new_candidate = match Candidate::server_reflexive(observed_address, Protocol::Udp) {
Ok(c) => c,
Err(e) => {
tracing::debug!("Observed address is not a valid candidate: {e}");
return true; // We still handled the packet correctly.
}
};
match &self.last_candidate {
Some(candidate) if candidate != &new_candidate => {
tracing::info!(current = %candidate, new = %new_candidate, "Updating server-reflexive candidate");
}
None => {
tracing::info!(new = %new_candidate, "New server-reflexive candidate");
}
_ => return true,
}
self.last_candidate = Some(new_candidate.clone());
self.buffered_events
.push_back(Event::NewServerReflexiveCandidate {
candidate: new_candidate,
});
true
}
pub fn poll_event(&mut self) -> Option<Event> {
self.buffered_events.pop_front()
}
pub fn poll_timeout(&mut self) -> Option<Instant> {
match self.state {
State::Initial => None,
State::SentRequest { at, .. } => Some(at + STUN_TIMEOUT),
State::ReceivedResponse { at } => Some(at + STUN_REFRESH),
}
}
pub fn handle_timeout(&mut self, now: Instant) {
self.last_now = Some(now);
match self.state {
State::Initial => {
tracing::debug!(server = %self.server, "Sending new STUN request");
}
State::SentRequest { id, at } if at + STUN_TIMEOUT <= now => {
tracing::debug!(?id, "STUN request timed out, sending new one");
}
State::ReceivedResponse { at } if at + STUN_REFRESH <= now => {
tracing::debug!("Refreshing STUN binding");
}
_ => return,
}
let request = new_stun_request();
self.state = State::SentRequest {
id: request.transaction_id(),
at: now,
};
self.buffered_transmits.push_back(Transmit {
dst: self.server,
payload: encode(request),
});
}
pub fn poll_transmit(&mut self) -> Option<Transmit> {
self.buffered_transmits.pop_front()
}
#[cfg(test)]
fn set_received_at(&mut self, address: SocketAddr, now: Instant) {
self.last_now = Some(now);
self.last_candidate = Some(Candidate::server_reflexive(address, Protocol::Udp).unwrap());
self.state = State::ReceivedResponse { at: now };
}
}
#[derive(Debug)]
pub enum Event {
NewServerReflexiveCandidate { candidate: Candidate },
}
fn new_stun_request() -> Message<rfc5389::Attribute> {
Message::new(
stun_codec::MessageClass::Request,
rfc5389::methods::BINDING,
TransactionId::new(rand::random()),
)
}
fn encode<A>(message: Message<A>) -> Vec<u8>
where
A: Attribute,
{
stun_codec::MessageEncoder::<A>::default()
.encode_into_bytes(message)
.unwrap()
}
#[derive(Debug)]
enum State {
Initial,
SentRequest { id: TransactionId, at: Instant },
ReceivedResponse { at: Instant },
}
#[cfg(test)]
mod tests {
use super::*;
use bytecodec::DecodeExt;
use std::{
net::{Ipv4Addr, SocketAddrV4},
time::Duration,
};
use stun_codec::{
rfc5389::{attributes::XorMappedAddress, methods::BINDING},
Message,
};
const SERVER1: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3478));
const SERVER2: SocketAddr =
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 3478));
const MAPPED_ADDRESS: SocketAddr =
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 9999));
#[test]
fn initial_binding_sends_request() {
let mut stun_binding = StunBinding::new(SERVER1);
stun_binding.handle_timeout(Instant::now());
let transmit = stun_binding.poll_transmit().unwrap();
assert_eq!(transmit.dst, SERVER1);
}
#[test]
fn repeated_polling_does_not_generate_more_requests() {
let mut stun_binding = StunBinding::new(SERVER1);
stun_binding.handle_timeout(Instant::now());
assert!(stun_binding.poll_transmit().is_some());
assert!(stun_binding.poll_transmit().is_none());
}
#[test]
fn request_times_out_after_5_seconds_and_generates_new_request() {
let mut stun_binding = StunBinding::new(SERVER1);
let start = Instant::now();
stun_binding.handle_timeout(start);
assert!(stun_binding.poll_transmit().is_some());
assert!(stun_binding.poll_transmit().is_none());
assert_eq!(
stun_binding.poll_timeout().unwrap(),
start + Duration::from_secs(5)
);
// Nothing after 1 second ..
stun_binding.handle_timeout(start + Duration::from_secs(1));
assert!(stun_binding.poll_transmit().is_none());
// Nothing after 2 seconds ..
stun_binding.handle_timeout(start + Duration::from_secs(2));
assert!(stun_binding.poll_transmit().is_none());
stun_binding.handle_timeout(start + Duration::from_secs(5));
assert!(stun_binding.poll_transmit().is_some());
assert!(stun_binding.poll_transmit().is_none());
}
#[test]
fn mapped_address_is_emitted_as_event() {
let mut stun_binding = StunBinding::new(SERVER1);
let start = Instant::now();
stun_binding.handle_timeout(start);
let request = stun_binding.poll_transmit().unwrap();
let response = generate_stun_response(request, MAPPED_ADDRESS);
let handled =
stun_binding.handle_input(SERVER1, &response, start + Duration::from_millis(200));
assert!(handled);
let Event::NewServerReflexiveCandidate { candidate } = stun_binding.poll_event().unwrap();
assert_eq!(candidate.addr(), MAPPED_ADDRESS);
}
#[test]
fn stun_binding_is_refreshed_every_five_minutes() {
let start = Instant::now();
let mut stun_binding = StunBinding::new(SERVER1);
stun_binding.set_received_at(MAPPED_ADDRESS, start);
assert!(stun_binding.poll_transmit().is_none());
stun_binding.handle_timeout(start + Duration::from_secs(5 * 60));
assert!(stun_binding.poll_transmit().is_some());
}
#[test]
fn response_from_other_server_is_discarded() {
let mut stun_binding = StunBinding::new(SERVER1);
let start = Instant::now();
stun_binding.handle_timeout(start);
let request = stun_binding.poll_transmit().unwrap();
let response = generate_stun_response(request, MAPPED_ADDRESS);
let handled =
stun_binding.handle_input(SERVER2, &response, start + Duration::from_millis(200));
assert!(!handled);
assert!(stun_binding.poll_event().is_none());
}
fn generate_stun_response(request: Transmit, mapped_address: SocketAddr) -> Vec<u8> {
let mut decoder = stun_codec::MessageDecoder::<stun_codec::rfc5389::Attribute>::default();
let message = decoder
.decode_from_bytes(&request.payload)
.unwrap()
.unwrap();
let transaction_id = message.transaction_id();
let mut response = Message::<rfc5389::Attribute>::new(
stun_codec::MessageClass::SuccessResponse,
BINDING,
transaction_id,
);
response.add_attribute(stun_codec::rfc5389::Attribute::XorMappedAddress(
XorMappedAddress::new(mapped_address),
));
encode(response)
}
}