diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index f8573323c..97c8fe70b 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -5,7 +5,7 @@ use core::fmt; use rand::rngs::mock::StepRng; use rand::rngs::ThreadRng; use rand::Rng; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashMap, VecDeque}; use std::hash::Hash; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::time::{Duration, Instant}; @@ -34,8 +34,8 @@ pub struct Server { allocations: HashMap, clients_by_allocation: HashMap, + allocations_by_port: HashMap, - used_ports: HashSet, pending_commands: VecDeque, next_allocation_id: AllocationId, @@ -109,7 +109,7 @@ impl Server { public_ip4_address, allocations: Default::default(), clients_by_allocation: Default::default(), - used_ports: Default::default(), + allocations_by_port: Default::default(), pending_commands: Default::default(), next_allocation_id: AllocationId(1), rng: rand::thread_rng(), @@ -241,7 +241,7 @@ where return Err(AllocationMismatch.into()); } - if self.used_ports.len() == MAX_AVAILABLE_PORTS as usize { + if self.allocations_by_port.len() == MAX_AVAILABLE_PORTS as usize { return Err(InsufficientCapacity.into()); } @@ -361,15 +361,14 @@ where // First, find an unused port. assert!( - self.used_ports.len() < MAX_AVAILABLE_PORTS as usize, + self.allocations_by_port.len() < MAX_AVAILABLE_PORTS as usize, "No more ports available; this would loop forever" ); let port = loop { let candidate = self.rng.gen_range(LOWEST_PORT..HIGHEST_PORT); - if !self.used_ports.contains(&candidate) { - self.used_ports.insert(candidate); + if !self.allocations_by_port.contains_key(&candidate) { break candidate; } }; @@ -377,6 +376,8 @@ where // Second, grab a new allocation ID. let id = self.next_allocation_id.next(); + self.allocations_by_port.insert(port, id); + Allocation { id, port, @@ -439,7 +440,7 @@ impl Server { public_ip4_address: SocketAddrV4::new(local_ip4_address, 3478), allocations: HashMap::new(), clients_by_allocation: Default::default(), - used_ports: HashSet::new(), + allocations_by_port: Default::default(), next_allocation_id: AllocationId::default(), pending_commands: VecDeque::new(), rng: StepRng::new(0, 0),