mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 02:18:47 +00:00
refactor(connlib): delay initialization of Sockets until we have a tokio runtime (#4286)
Our sockets need to be initialized within a tokio runtime context. To achieve this, we don't actually initialize anything on `Sockets::new`. Instead, we call `rebind` within the constructor of `Tunnel` which already runs in a tokio context. Fixes: #4282 --------- Signed-off-by: Jamil <jamilbk@users.noreply.github.com> Co-authored-by: Thomas Eizinger <thomas@eizinger.io> Co-authored-by: Reactor Scram <ReactorScram@users.noreply.github.com>
This commit is contained in:
@@ -388,28 +388,27 @@ fn connect(
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let sockets = Sockets::new()?;
|
||||
|
||||
if let Some(ip4_socket) = sockets.ip4_socket_fd() {
|
||||
callback_handler.protect_file_descriptor(ip4_socket)?;
|
||||
}
|
||||
if let Some(ip6_socket) = sockets.ip6_socket_fd() {
|
||||
callback_handler.protect_file_descriptor(ip6_socket)?;
|
||||
}
|
||||
let sockets = Sockets::with_protect({
|
||||
let callbacks = callback_handler.clone();
|
||||
move |fd| {
|
||||
callbacks
|
||||
.protect_file_descriptor(fd)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
});
|
||||
|
||||
let session = Session::connect(
|
||||
login,
|
||||
sockets,
|
||||
private_key,
|
||||
Some(os_version),
|
||||
callback_handler.clone(),
|
||||
callback_handler,
|
||||
Some(MAX_PARTITION_TIME),
|
||||
runtime.handle().clone(),
|
||||
)?;
|
||||
|
||||
Ok(SessionWrapper {
|
||||
inner: session,
|
||||
callbacks: callback_handler,
|
||||
runtime,
|
||||
})
|
||||
}
|
||||
@@ -463,29 +462,11 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_co
|
||||
|
||||
pub struct SessionWrapper {
|
||||
inner: Session,
|
||||
callbacks: CallbackHandler,
|
||||
|
||||
#[allow(dead_code)] // Only here so we don't drop the memory early.
|
||||
runtime: Runtime,
|
||||
}
|
||||
|
||||
impl SessionWrapper {
|
||||
fn reconnect(&self) -> Result<(), CallbackError> {
|
||||
let sockets = Sockets::new()?;
|
||||
|
||||
if let Some(ip4_socket) = sockets.ip4_socket_fd() {
|
||||
self.callbacks.protect_file_descriptor(ip4_socket)?;
|
||||
}
|
||||
if let Some(ip6_socket) = sockets.ip6_socket_fd() {
|
||||
self.callbacks.protect_file_descriptor(ip6_socket)?;
|
||||
}
|
||||
|
||||
self.inner.reconnect(sockets);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Pointers must be valid
|
||||
#[allow(non_snake_case)]
|
||||
@@ -528,11 +509,9 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se
|
||||
#[allow(non_snake_case)]
|
||||
#[no_mangle]
|
||||
pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_reconnect(
|
||||
mut env: JNIEnv,
|
||||
_: JNIEnv,
|
||||
_: JClass,
|
||||
session: *const SessionWrapper,
|
||||
) {
|
||||
if let Err(e) = (*session).reconnect() {
|
||||
throw(&mut env, "java/lang/Exception", e.to_string());
|
||||
}
|
||||
(*session).inner.reconnect();
|
||||
}
|
||||
|
||||
@@ -205,7 +205,7 @@ impl WrappedSession {
|
||||
|
||||
let session = Session::connect(
|
||||
login,
|
||||
Sockets::new().map_err(|err| err.to_string())?,
|
||||
Sockets::new(),
|
||||
private_key,
|
||||
os_version_override,
|
||||
CallbackHandler {
|
||||
|
||||
@@ -10,7 +10,7 @@ use connlib_shared::{
|
||||
messages::{ConnectionAccepted, GatewayResponse, ResourceAccepted, ResourceId},
|
||||
Callbacks,
|
||||
};
|
||||
use firezone_tunnel::{ClientTunnel, Sockets};
|
||||
use firezone_tunnel::ClientTunnel;
|
||||
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@@ -37,7 +37,7 @@ pub struct Eventloop<C: Callbacks> {
|
||||
/// Commands that can be sent to the [`Eventloop`].
|
||||
pub enum Command {
|
||||
Stop,
|
||||
Reconnect(Sockets),
|
||||
Reconnect,
|
||||
SetDns(Vec<IpAddr>),
|
||||
}
|
||||
|
||||
@@ -71,9 +71,11 @@ where
|
||||
tracing::warn!("Failed to update DNS: {e}");
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Command::Reconnect(sockets))) => {
|
||||
Poll::Ready(Some(Command::Reconnect)) => {
|
||||
self.portal.reconnect();
|
||||
self.tunnel.reconnect(sockets);
|
||||
if let Err(e) = self.tunnel.reconnect() {
|
||||
tracing::warn!("Failed to reconnect tunnel: {e}");
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -67,11 +67,11 @@ impl Session {
|
||||
///
|
||||
/// - Close and re-open a connection to the portal.
|
||||
/// - Refresh all allocations
|
||||
/// - Replace the currently used [`Sockets`] with the provided one
|
||||
/// - Rebind local UDP sockets
|
||||
///
|
||||
/// # Implementation note
|
||||
///
|
||||
/// The reason we replace [`Sockets`] are:
|
||||
/// The reason we rebind the UDP sockets are:
|
||||
///
|
||||
/// 1. On MacOS, as socket bound to the unspecified IP cannot send to interfaces attached after the socket has been created.
|
||||
/// 2. Switching between networks changes the 3-tuple of the client.
|
||||
@@ -80,9 +80,9 @@ impl Session {
|
||||
/// Changing the IP would be enough for that.
|
||||
/// However, if the user would now change _back_ to the previous network,
|
||||
/// the TURN server would recognise the old allocation but the client already lost all its state associated with it.
|
||||
/// To avoid race-conditions like this, we initialize a new [`Sockets`] instance which allocates a new port.
|
||||
pub fn reconnect(&self, sockets: Sockets) {
|
||||
let _ = self.channel.send(Command::Reconnect(sockets));
|
||||
/// To avoid race-conditions like this, we rebind the sockets to a new port.
|
||||
pub fn reconnect(&self) {
|
||||
let _ = self.channel.send(Command::Reconnect);
|
||||
}
|
||||
|
||||
/// Sets a new set of upstream DNS servers for this [`Session`].
|
||||
@@ -118,7 +118,7 @@ async fn connect<CB>(
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
let tunnel = ClientTunnel::new(private_key, sockets, callbacks.clone());
|
||||
let tunnel = ClientTunnel::new(private_key, sockets, callbacks.clone())?;
|
||||
|
||||
let portal = PhoenixChannel::connect(
|
||||
Secret::new(url),
|
||||
|
||||
@@ -43,8 +43,13 @@ pub enum Input<'a, I> {
|
||||
}
|
||||
|
||||
impl Io {
|
||||
pub fn new(sockets: Sockets) -> Self {
|
||||
Self {
|
||||
/// Creates a new I/O abstraction
|
||||
///
|
||||
/// Must be called within a Tokio runtime context so we can bind the sockets.
|
||||
pub fn new(mut sockets: Sockets) -> io::Result<Self> {
|
||||
sockets.rebind()?; // Bind sockets on startup. Must happen within a tokio runtime context.
|
||||
|
||||
Ok(Self {
|
||||
device: Device::new(),
|
||||
timeout: None,
|
||||
sockets,
|
||||
@@ -53,7 +58,7 @@ impl Io {
|
||||
Duration::from_secs(60),
|
||||
DNS_QUERIES_QUEUE_SIZE,
|
||||
),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn poll<'b>(
|
||||
@@ -115,8 +120,8 @@ impl Io {
|
||||
&self.sockets
|
||||
}
|
||||
|
||||
pub(crate) fn set_sockets(&mut self, sockets: Sockets) {
|
||||
self.sockets = sockets;
|
||||
pub fn sockets_mut(&mut self) -> &mut Sockets {
|
||||
&mut self.sockets
|
||||
}
|
||||
|
||||
pub fn set_upstream_dns_servers(
|
||||
|
||||
@@ -59,21 +59,27 @@ impl<CB> ClientTunnel<CB>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self {
|
||||
Self {
|
||||
io: Io::new(sockets),
|
||||
pub fn new(
|
||||
private_key: StaticSecret,
|
||||
sockets: Sockets,
|
||||
callbacks: CB,
|
||||
) -> std::io::Result<Self> {
|
||||
Ok(Self {
|
||||
io: Io::new(sockets)?,
|
||||
callbacks,
|
||||
role_state: ClientState::new(private_key),
|
||||
write_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
device_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reconnect(&mut self, sockets: Sockets) {
|
||||
pub fn reconnect(&mut self) -> std::io::Result<()> {
|
||||
self.role_state.reconnect(Instant::now());
|
||||
self.io.set_sockets(sockets);
|
||||
self.io.sockets_mut().rebind()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<ClientEvent>> {
|
||||
@@ -149,16 +155,20 @@ impl<CB> GatewayTunnel<CB>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self {
|
||||
Self {
|
||||
io: Io::new(sockets),
|
||||
pub fn new(
|
||||
private_key: StaticSecret,
|
||||
sockets: Sockets,
|
||||
callbacks: CB,
|
||||
) -> std::io::Result<Self> {
|
||||
Ok(Self {
|
||||
io: Io::new(sockets)?,
|
||||
callbacks,
|
||||
role_state: GatewayState::new(private_key),
|
||||
write_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
device_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<GatewayEvent>> {
|
||||
|
||||
@@ -13,10 +13,47 @@ use crate::Result;
|
||||
pub struct Sockets {
|
||||
socket_v4: Option<Socket>,
|
||||
socket_v6: Option<Socket>,
|
||||
|
||||
#[cfg(unix)]
|
||||
protect: Box<dyn Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static>,
|
||||
}
|
||||
|
||||
impl Default for Sockets {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Sockets {
|
||||
pub fn new() -> io::Result<Self> {
|
||||
#[cfg(unix)]
|
||||
pub fn with_protect(
|
||||
protect: impl Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static,
|
||||
) -> Self {
|
||||
Self {
|
||||
socket_v4: None,
|
||||
socket_v6: None,
|
||||
#[cfg(unix)]
|
||||
protect: Box::new(protect),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
socket_v4: None,
|
||||
socket_v6: None,
|
||||
#[cfg(unix)]
|
||||
protect: Box::new(|_| Ok(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn can_handle(&self, addr: &SocketAddr) -> bool {
|
||||
match addr {
|
||||
SocketAddr::V4(_) => self.socket_v4.is_some(),
|
||||
SocketAddr::V6(_) => self.socket_v6.is_some(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rebind(&mut self) -> io::Result<()> {
|
||||
let socket_v4 = Socket::ip4();
|
||||
let socket_v6 = Socket::ip6();
|
||||
|
||||
@@ -39,31 +76,23 @@ impl Sockets {
|
||||
_ => (),
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
socket_v4: socket_v4.ok(),
|
||||
socket_v6: socket_v6.ok(),
|
||||
})
|
||||
}
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
pub fn can_handle(&self, addr: &SocketAddr) -> bool {
|
||||
match addr {
|
||||
SocketAddr::V4(_) => self.socket_v4.is_some(),
|
||||
SocketAddr::V6(_) => self.socket_v6.is_some(),
|
||||
if let Ok(fd) = socket_v4.as_ref().map(|s| s.socket.as_raw_fd()) {
|
||||
(self.protect)(fd)?;
|
||||
}
|
||||
|
||||
if let Ok(fd) = socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) {
|
||||
(self.protect)(fd)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub fn ip4_socket_fd(&self) -> Option<std::os::fd::RawFd> {
|
||||
use std::os::fd::AsRawFd;
|
||||
self.socket_v4 = socket_v4.ok();
|
||||
self.socket_v6 = socket_v6.ok();
|
||||
|
||||
self.socket_v4.as_ref().map(|s| s.socket.as_raw_fd())
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub fn ip6_socket_fd(&self) -> Option<std::os::fd::RawFd> {
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
self.socket_v6.as_ref().map(|s| s.socket.as_raw_fd())
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flushes all buffered data on the sockets.
|
||||
|
||||
@@ -91,7 +91,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
|
||||
}
|
||||
|
||||
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
|
||||
let mut tunnel = GatewayTunnel::new(private_key, Sockets::new()?, CallbackHandler);
|
||||
let mut tunnel = GatewayTunnel::new(private_key, Sockets::new(), CallbackHandler)?;
|
||||
|
||||
let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>(
|
||||
Secret::new(login),
|
||||
|
||||
@@ -536,7 +536,7 @@ impl Controller {
|
||||
)?;
|
||||
let connlib = connlib_client_shared::Session::connect(
|
||||
login,
|
||||
Sockets::new()?,
|
||||
Sockets::new(),
|
||||
private_key,
|
||||
None,
|
||||
callback_handler.clone(),
|
||||
@@ -837,7 +837,7 @@ async fn run_controller(
|
||||
have_internet = new_have_internet;
|
||||
if let Some(session) = controller.session.as_mut() {
|
||||
tracing::debug!("Internet up/down changed, calling `Session::reconnect`");
|
||||
session.connlib.reconnect(Sockets::new()?);
|
||||
session.connlib.reconnect();
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -38,7 +38,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
let session = Session::connect(
|
||||
login,
|
||||
Sockets::new()?,
|
||||
Sockets::new(),
|
||||
private_key,
|
||||
None,
|
||||
callbacks.clone(),
|
||||
@@ -62,7 +62,7 @@ async fn main() -> Result<()> {
|
||||
if sighup.poll_recv(cx).is_ready() {
|
||||
tracing::debug!("Received SIGHUP");
|
||||
|
||||
session.reconnect(Sockets::new()?);
|
||||
session.reconnect();
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user