Signout when connlib calls onDisconnect (#2687)

Fixes #2304.

- If the tunnel disconnects because of connlib, or because the tunnel is
incorrectly configured (we don't expect that to happen), the user is
signed out.
- If the tunnel disconnects because the user disconnected it in
Settings, the user is not signed out, and no alert is shown
- If the tunnel disconnects because the OS brought it down for some
other reason (not sure what it could be), the user is not signed out,
and an alert is shown. Alert will be shown only if the app is running at
that time.
This commit is contained in:
Roopesh Chander
2023-11-25 16:38:51 +05:30
committed by GitHub
parent b9cd94ec82
commit c283610402
10 changed files with 359 additions and 59 deletions

View File

@@ -1,3 +1,3 @@
[codespell]
skip = ./website/.next,./website/pnpm-lock.yaml,./rust/target,Cargo.lock,./website/docs/reference/api/*.mdx,./erl_crash.dump,./apps/*/erl_crash.dump,./cover,./vendor,*.json,seeds.exs,./**/node_modules,./deps,./priv/static,./priv/plts,./**/priv/static,./.git,./_build
ignore-words-list = optin,crate,keypair,keypairs,iif,statics,wee,anull,commitish,inout,fo
ignore-words-list = optin,crate,keypair,keypairs,iif,statics,wee,anull,commitish,inout,fo,superceded

View File

@@ -16,7 +16,7 @@ import SwiftUI
private let logger = Logger.make(for: MainViewModel.self)
private var cancellables: Set<AnyCancellable> = []
private let appStore: AppStore
let appStore: AppStore
@Dependency(\.mainQueue) private var mainQueue
@Published var loginStatus: AuthStore.LoginStatus = .uninitialized
@@ -60,13 +60,7 @@ import SwiftUI
}
func signOutButtonTapped() {
Task {
do {
try await appStore.auth.signOut()
} catch {
logger.error("Error signing out: \(String(describing: error))")
}
}
appStore.auth.signOut()
}
func startTunnel() async {
@@ -80,7 +74,13 @@ import SwiftUI
}
func stopTunnel() {
appStore.tunnel.stop()
Task {
do {
try await appStore.tunnel.stop()
} catch {
logger.error("\(#function): Error stopping tunnel: \(error)")
}
}
}
}

View File

@@ -47,6 +47,10 @@ public struct SharedAccess {
return nil
}
public static var tunnelShutdownEventFileURL: URL {
baseFolderURL.appendingPathComponent("tunnel_shutdown_event_data.json")
}
private static func ensureDirectoryExists(at path: String) -> Bool {
let fileManager = FileManager.default
do {

View File

@@ -0,0 +1,120 @@
//
// DisconnectReason.swift
// (c) 2023 Firezone, Inc.
// LICENSE: Apache-2.0
//
import Foundation
import NetworkExtension
import os
enum TunnelShutdownEventError: Error {
case decodeError
case cannotGetFileURL
}
public struct TunnelShutdownEvent: Codable, CustomStringConvertible {
private static let logger = Logger.make(for: TunnelShutdownEvent.self)
public enum Reason: Codable, CustomStringConvertible {
case stopped(NEProviderStopReason)
case connlibConnectFailure
case connlibDisconnected
case badTunnelConfiguration
case tokenNotFound
case networkSettingsApplyFailure
case invalidAdapterState
public var description: String {
switch self {
case .stopped(let reason): return "stopped(reason code: \(reason.rawValue))"
case .connlibConnectFailure: return "connlib connection failure"
case .connlibDisconnected: return "connlib disconnected"
case .badTunnelConfiguration: return "bad tunnel configuration"
case .tokenNotFound: return "token not found"
case .networkSettingsApplyFailure: return "network settings apply failure"
case .invalidAdapterState: return "invalid adapter state"
}
}
}
public enum Action {
case signoutImmediately
case retryThenSignout
}
public let reason: TunnelShutdownEvent.Reason
public let errorMessage: String
public let date: Date
public var action: Action {
switch reason {
case .stopped(let reason):
if reason == .userInitiated || reason == .userLogout || reason == .userSwitch {
return .signoutImmediately
} else {
return .retryThenSignout
}
case .networkSettingsApplyFailure, .invalidAdapterState:
return .retryThenSignout
case .connlibConnectFailure, .connlibDisconnected,
.badTunnelConfiguration, .tokenNotFound:
return .signoutImmediately
}
}
public var description: String {
"(\(reason)\(action == .signoutImmediately ? " (needs immediate signout)" : ""), error: '\(errorMessage)', date: \(date))"
}
public init(reason: TunnelShutdownEvent.Reason, errorMessage: String) {
self.reason = reason
self.errorMessage = errorMessage
self.date = Date()
}
public static func loadFromDisk() -> TunnelShutdownEvent? {
let fileURL = SharedAccess.tunnelShutdownEventFileURL
let fileManager = FileManager.default
guard fileManager.fileExists(atPath: fileURL.path) else {
return nil
}
guard let jsonData = try? Data(contentsOf: fileURL) else {
Self.logger.error("Could not read tunnel shutdown event from disk at: \(fileURL)")
return nil
}
guard let reason = try? JSONDecoder().decode(TunnelShutdownEvent.self, from: jsonData) else {
Self.logger.error("Error decoding tunnel shutdown event from disk at: \(fileURL)")
return nil
}
do {
try fileManager.removeItem(atPath: fileURL.path)
} catch {
Self.logger.error("Cannot remove tunnel shutdown event file at \(fileURL.path)")
}
return reason
}
public static func saveToDisk(reason: TunnelShutdownEvent.Reason, errorMessage: String) {
let fileURL = SharedAccess.tunnelShutdownEventFileURL
Self.logger.error("Saving tunnel shutdown event data to \(fileURL, privacy: .public)")
let tsEvent = TunnelShutdownEvent(
reason: reason,
errorMessage: errorMessage)
do {
try JSONEncoder().encode(tsEvent).write(to: fileURL)
} catch {
Self.logger.error(
"Error writing tunnel shutdown event data to disk to: \(fileURL, privacy: .public): \(error, privacy: .public)"
)
}
}
}
extension NEProviderStopReason: Codable {
}

View File

@@ -42,24 +42,33 @@ final class AppStore: ObservableObject {
}
private func handleLoginStatusChanged(_ loginStatus: AuthStore.LoginStatus) async {
logger.log("\(#function): login status = \(loginStatus)")
switch loginStatus {
case .signedIn:
do {
try await tunnel.start()
} catch {
logger.error("Error starting tunnel: \(String(describing: error))")
logger.error("\(#function): Error starting tunnel: \(String(describing: error))")
}
case .signedOut:
tunnel.stop()
do {
try await tunnel.stop()
} catch {
logger.error("\(#function): Error stopping tunnel: \(String(describing: error))")
}
case .uninitialized:
break
}
}
private func signOutAndStopTunnel() {
tunnel.stop()
Task {
try? await auth.signOut()
do {
try await tunnel.stop()
auth.signOut()
} catch {
logger.error("\(#function): Error stopping tunnel: \(String(describing: error))")
}
}
}
}

View File

@@ -7,6 +7,7 @@
import Combine
import Dependencies
import Foundation
import NetworkExtension
import OSLog
extension AuthStore: DependencyKey {
@@ -26,7 +27,7 @@ final class AuthStore: ObservableObject {
static let shared = AuthStore(tunnelStore: TunnelStore.shared)
enum LoginStatus {
enum LoginStatus: CustomStringConvertible {
case uninitialized
case signedOut(accountId: String?)
case signedIn(accountId: String, actorName: String)
@@ -38,6 +39,17 @@ final class AuthStore: ObservableObject {
case .signedIn(let accountId, _): return accountId
}
}
var description: String {
switch self {
case .uninitialized:
return "uninitialized"
case .signedOut(let accountId):
return "signedOut(accountId: \(accountId ?? "nil"))"
case .signedIn(let accountId, let actorName):
return "signedIn(accountId: \(accountId), actorName: \(actorName))"
}
}
}
@Dependency(\.keychain) private var keychain
@@ -48,6 +60,10 @@ final class AuthStore: ObservableObject {
private var cancellables = Set<AnyCancellable>()
@Published private(set) var loginStatus: LoginStatus
private var status: NEVPNStatus = .invalid
private static let maxReconnectionAttemptCount = 3
private var reconnectionAttemptsRemaining = maxReconnectionAttemptCount
private init(tunnelStore: TunnelStore) {
self.tunnelStore = tunnelStore
@@ -60,8 +76,53 @@ final class AuthStore: ObservableObject {
tunnelStore.$tunnelAuthStatus
.sink { [weak self] tunnelAuthStatus in
guard let self = self else { return }
logger.log("Tunnel auth status changed to: \(tunnelAuthStatus)")
Task {
self.loginStatus = await self.getLoginStatus(from: tunnelAuthStatus)
let loginStatus = await self.getLoginStatus(from: tunnelAuthStatus)
if tunnelStore.tunnelAuthStatus == tunnelAuthStatus {
// Make sure the tunnelAuthStatus hasn't changed while we were getting the login status
self.loginStatus = loginStatus
}
}
}
.store(in: &cancellables)
tunnelStore.$status
.sink { [weak self] status in
guard let self = self else { return }
Task {
if status == .disconnected {
self.logger.log("\(#function): Disconnected")
if let tsEvent = TunnelShutdownEvent.loadFromDisk() {
self.logger.log(
"\(#function): Tunnel shutdown event: \(tsEvent, privacy: .public)"
)
switch tsEvent.action {
case .signoutImmediately:
self.signOut()
case .retryThenSignout:
let shouldReconnect = (self.reconnectionAttemptsRemaining > 0)
self.reconnectionAttemptsRemaining = self.reconnectionAttemptsRemaining - 1
if shouldReconnect {
self.logger.log(
"\(#function): Will try to reconnect after 1 second (\(self.reconnectionAttemptsRemaining) attempts after this)"
)
DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) {
self.logger.log("\(#function): Trying to reconnect")
self.startTunnel()
}
} else {
self.signOut()
}
}
} else {
self.logger.log("\(#function): Tunnel shutdown event not found")
}
}
if status == .connected {
self.resetReconnectionAttemptsRemaining()
}
self.status = status
}
}
.store(in: &cancellables)
@@ -124,10 +185,11 @@ final class AuthStore: ObservableObject {
try await signIn(accountId: accountId)
}
func signOut() async throws {
func signOut() {
logger.trace("\(#function)")
guard case .signedIn = self.loginStatus else {
guard case .signedIn = self.tunnelStore.tunnelAuthStatus else {
logger.trace("\(#function): Not signed in, so can't signout.")
return
}
@@ -136,6 +198,29 @@ final class AuthStore: ObservableObject {
try await keychain.delete(tokenRef)
}
}
resetReconnectionAttemptsRemaining()
}
func startTunnel() {
logger.trace("\(#function)")
guard case .signedIn = self.tunnelStore.tunnelAuthStatus else {
logger.trace("\(#function): Not signed in, so can't start the tunnel.")
return
}
Task {
do {
try await tunnelStore.start()
} catch {
logger.error("\(#function): Error starting tunnel: \(String(describing: error))")
}
}
}
func resetReconnectionAttemptsRemaining() {
self.reconnectionAttemptsRemaining = Self.maxReconnectionAttemptCount
}
func tunnelAuthStatusForAccount(accountId: String) async -> TunnelAuthStatus {

View File

@@ -11,6 +11,7 @@ import OSLog
enum TunnelStoreError: Error {
case tunnelCouldNotBeStarted
case tunnelCouldNotBeStopped
}
public struct TunnelProviderKeys {
@@ -39,6 +40,7 @@ final class TunnelStore: ObservableObject {
private var tunnelObservingTasks: [Task<Void, Never>] = []
private var startTunnelContinuation: CheckedContinuation<(), Error>?
private var stopTunnelContinuation: CheckedContinuation<(), Error>?
private var cancellables = Set<AnyCancellable>()
init() {
@@ -70,6 +72,7 @@ final class TunnelStore: ObservableObject {
}
self.tunnel = tunnel
self.tunnelAuthStatus = tunnel.authStatus()
self.status = tunnel.connection.status
} else {
let tunnel = NETunnelProviderManager()
tunnel.localizedDescription = "Firezone"
@@ -91,11 +94,7 @@ final class TunnelStore: ObservableObject {
fatalError("Tunnel not initialized yet")
}
let wasConnected =
(tunnel.connection.status == .connected || tunnel.connection.status == .connecting)
if wasConnected {
stop()
}
try await stop()
try await tunnel.saveAuthStatus(tunnelAuthStatus)
self.tunnelAuthStatus = tunnelAuthStatus
@@ -106,11 +105,7 @@ final class TunnelStore: ObservableObject {
fatalError("Tunnel not initialized yet")
}
let wasConnected =
(tunnel.connection.status == .connected || tunnel.connection.status == .connecting)
if wasConnected {
stop()
}
try await stop()
try await tunnel.saveAdvancedSettings(advancedSettings)
}
@@ -164,15 +159,22 @@ final class TunnelStore: ObservableObject {
}
}
func stop() {
func stop() async throws {
guard let tunnel = tunnel else {
Self.logger.log("\(#function): TunnelStore is not initialized")
return
}
TunnelStore.logger.trace("\(#function)")
let session = castToSession(tunnel.connection)
session.stopTunnel()
let status = tunnel.connection.status
if status == .connected || status == .connecting {
let session = castToSession(tunnel.connection)
session.stopTunnel()
try await withCheckedThrowingContinuation { continuation in
self.stopTunnelContinuation = continuation
}
}
}
func stopAndSignOut() async throws -> Keychain.PersistentRef? {
@@ -182,8 +184,12 @@ final class TunnelStore: ObservableObject {
}
TunnelStore.logger.trace("\(#function)")
let session = castToSession(tunnel.connection)
session.stopTunnel()
let status = tunnel.connection.status
if status == .connected || status == .connecting {
let session = castToSession(tunnel.connection)
session.stopTunnel()
}
if case .signedIn(let authBaseURL, let accountId, let tokenReference) = self.tunnelAuthStatus {
try await saveAuthStatus(.signedOut(authBaseURL: authBaseURL, accountId: accountId))
@@ -279,6 +285,18 @@ final class TunnelStore: ObservableObject {
break
}
}
if let stopTunnelContinuation = self.stopTunnelContinuation {
switch status {
case .disconnected:
stopTunnelContinuation.resume(returning: ())
self.stopTunnelContinuation = nil
case .connected:
stopTunnelContinuation.resume(throwing: TunnelStoreError.tunnelCouldNotBeStopped)
self.stopTunnelContinuation = nil
default:
break
}
}
if status != .connected {
self.resources = DisplayableResources()
}
@@ -298,7 +316,7 @@ final class TunnelStore: ObservableObject {
}
}
enum TunnelAuthStatus {
enum TunnelAuthStatus: Equatable, CustomStringConvertible {
case tunnelUninitialized
case accountNotSetup
case signedOut(authBaseURL: URL, accountId: String)
@@ -321,6 +339,19 @@ enum TunnelAuthStatus {
return accountId
}
}
var description: String {
switch self {
case .tunnelUninitialized:
return "tunnel uninitialized"
case .accountNotSetup:
return "account not setup"
case .signedOut(let authBaseURL, let accountId):
return "signedOut(authBaseURL: \(authBaseURL), accountId: \(accountId))"
case .signedIn(let authBaseURL, let accountId, _):
return "signedIn(authBaseURL: \(authBaseURL), accountId: \(accountId))"
}
}
}
// MARK: - Extensions

View File

@@ -127,7 +127,7 @@
private lazy var resourcesUnavailableReasonMenuItem = createMenuItem(
menu,
title: "",
action: #selector(reconnectButtonTapped),
action: nil,
isHidden: true,
target: self
)
@@ -225,13 +225,7 @@
}
@objc private func signOutButtonTapped() {
Task {
do {
try await appStore?.auth.signOut()
} catch {
logger.error("error signing out: \(String(describing: error))")
}
}
appStore?.auth.signOut()
}
@objc private func settingsButtonTapped() {
@@ -315,6 +309,7 @@
case .signedOut:
signInMenuItem.title = "Sign In"
signInMenuItem.target = self
signInMenuItem.isEnabled = true
signOutMenuItem.isHidden = true
case .signedIn(_, let actorName):
signInMenuItem.title = actorName.isEmpty ? "Signed in" : "Signed in as \(actorName)"
@@ -360,16 +355,14 @@
resourcesUnavailableReasonMenuItem.target = nil
resourcesUnavailableReasonMenuItem.title = "Disconnecting…"
resourcesSeparatorMenuItem.isHidden = false
case (.signedIn, _):
// Ideally, this shouldn't happen, but it's better
// we handle this case, so that in case connlib errors out,
// the user is able to try to reconnect.
case (.signedIn, .disconnected), (.signedIn, .invalid), (.signedIn, _):
// We should never be in a state where the tunnel is
// down but the user is signed in, but we have
// code to handle it just for the sake of completion.
resourcesTitleMenuItem.isHidden = true
resourcesUnavailableMenuItem.isHidden = false
resourcesUnavailableReasonMenuItem.isHidden = false
resourcesUnavailableReasonMenuItem.target = self
resourcesUnavailableReasonMenuItem.isEnabled = true
resourcesUnavailableReasonMenuItem.title = "Reconnect"
resourcesUnavailableReasonMenuItem.title = "Disconnected"
resourcesSeparatorMenuItem.isHidden = false
}
}

View File

@@ -54,7 +54,7 @@ private enum AdapterState: CustomStringConvertible {
}
// Loosely inspired from WireGuardAdapter from WireGuardKit
public class Adapter {
class Adapter {
typealias StartTunnelCompletionHandler = ((AdapterError?) -> Void)
typealias StopTunnelCompletionHandler = (() -> Void)
@@ -67,7 +67,7 @@ public class Adapter {
private var networkSettings: NetworkSettings?
/// Packet tunnel provider.
private weak var packetTunnelProvider: NEPacketTunnelProvider?
private weak var packetTunnelProvider: PacketTunnelProvider?
/// Network routes monitor.
private var networkMonitor: NWPathMonitor?
@@ -92,9 +92,9 @@ public class Adapter {
private let logFilter: String
private let connlibLogFolderPath: String
public init(
init(
controlPlaneURLString: String, token: String,
logFilter: String, packetTunnelProvider: NEPacketTunnelProvider
logFilter: String, packetTunnelProvider: PacketTunnelProvider
) {
self.controlPlaneURLString = controlPlaneURLString
self.token = token
@@ -126,6 +126,9 @@ public class Adapter {
self.logger.log("Adapter.start")
guard case .stoppedTunnel = self.state else {
packetTunnelProvider?.handleTunnelShutdown(
dueTo: .invalidAdapterState,
errorMessage: "Adapter is in invalid state")
completionHandler(.invalidState)
return
}
@@ -146,20 +149,26 @@ public class Adapter {
)
} catch let error {
self.logger.error("Adapter.start: Error: \(error, privacy: .public)")
packetTunnelProvider?.handleTunnelShutdown(
dueTo: .connlibConnectFailure,
errorMessage: error.localizedDescription)
self.state = .stoppedTunnel
completionHandler(AdapterError.connlibConnectError(error))
}
}
}
/// Stop the tunnel
public func stop(completionHandler: @escaping () -> Void) {
public func stop(reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
workQueue.async { [weak self] in
guard let self = self else { return }
self.logger.log("Adapter.stop")
packetTunnelProvider?.handleTunnelShutdown(
dueTo: .stopped(reason),
errorMessage: "\(reason)")
switch self.state {
case .stoppedTunnel, .stoppingTunnel:
break
@@ -359,6 +368,9 @@ extension Adapter: CallbackHandlerDelegate {
}
networkSettings.apply(on: packetTunnelProvider, logger: self.logger) { error in
if let error = error {
packetTunnelProvider.handleTunnelShutdown(
dueTo: .networkSettingsApplyFailure,
errorMessage: error.localizedDescription)
onStarted?(AdapterError.setNetworkSettings(error))
self.state = .stoppedTunnel
} else {
@@ -449,7 +461,7 @@ extension Adapter: CallbackHandlerDelegate {
workQueue.async { [weak self] in
guard let self = self else { return }
self.logger.log("Adapter.onDisconnect")
self.logger.log("Adapter.onDisconnect: \(error ?? "No error", privacy: .public)")
if let errorMessage = error {
self.logger.error(
"Connlib disconnected with unrecoverable error: \(errorMessage, privacy: .public)")
@@ -466,6 +478,9 @@ extension Adapter: CallbackHandlerDelegate {
case .stoppedTunnelTemporarily:
self.state = .stoppedTunnel
default:
packetTunnelProvider?.handleTunnelShutdown(
dueTo: .connlibDisconnected,
errorMessage: errorMessage)
self.packetTunnelProvider?.cancelTunnelWithError(
AdapterError.connlibFatalError(errorMessage))
self.state = .stoppedTunnel
@@ -480,7 +495,6 @@ extension Adapter: CallbackHandlerDelegate {
onStopped?()
self.state = .stoppedTunnelTemporarily
default:
// This should not happen
self.state = .stoppedTunnel
}
}

View File

@@ -28,6 +28,9 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
guard let controlPlaneURLString = protocolConfiguration.serverAddress else {
Self.logger.error("serverAddress is missing")
self.handleTunnelShutdown(
dueTo: .badTunnelConfiguration,
errorMessage: "serverAddress is missing")
completionHandler(
PacketTunnelProviderError.savedProtocolConfigurationIsInvalid("serverAddress"))
return
@@ -35,6 +38,9 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
guard let tokenRef = protocolConfiguration.passwordReference else {
Self.logger.error("passwordReference is missing")
self.handleTunnelShutdown(
dueTo: .badTunnelConfiguration,
errorMessage: "passwordReference is missing")
completionHandler(
PacketTunnelProviderError.savedProtocolConfigurationIsInvalid("passwordReference"))
return
@@ -45,6 +51,9 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
guard let connlibLogFilter = providerConfig?[TunnelProviderKeys.keyConnlibLogFilter] as? String
else {
Self.logger.error("connlibLogFilter is missing")
self.handleTunnelShutdown(
dueTo: .badTunnelConfiguration,
errorMessage: "connlibLogFilter is missing")
completionHandler(
PacketTunnelProviderError.savedProtocolConfigurationIsInvalid("connlibLogFilter"))
return
@@ -53,6 +62,9 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
Task {
let keychain = Keychain()
guard let token = await keychain.load(persistentRef: tokenRef) else {
self.handleTunnelShutdown(
dueTo: .tokenNotFound,
errorMessage: "Token not found in keychain")
completionHandler(PacketTunnelProviderError.tokenNotFoundInKeychain)
return
}
@@ -74,8 +86,11 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
}
}
override func stopTunnel(with _: NEProviderStopReason, completionHandler: @escaping () -> Void) {
adapter?.stop {
override func stopTunnel(
with reason: NEProviderStopReason, completionHandler: @escaping () -> Void
) {
Self.logger.log("stopTunnel: Reason: \(reason)")
adapter?.stop(reason: reason) {
completionHandler()
#if os(macOS)
// HACK: This is a filthy hack to work around Apple bug 32073323
@@ -91,4 +106,33 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
completionHandler?(displayableResources?.toData())
}
}
func handleTunnelShutdown(dueTo reason: TunnelShutdownEvent.Reason, errorMessage: String) {
TunnelShutdownEvent.saveToDisk(reason: reason, errorMessage: errorMessage)
}
}
extension NEProviderStopReason: CustomStringConvertible {
public var description: String {
switch self {
case .none: return "None"
case .userInitiated: return "User-initiated"
case .providerFailed: return "Provider failed"
case .noNetworkAvailable: return "No network available"
case .unrecoverableNetworkChange: return "Unrecoverable network change"
case .providerDisabled: return "Provider disabled"
case .authenticationCanceled: return "Authentication cancelled"
case .configurationFailed: return "Configuration failed"
case .idleTimeout: return "Idle timeout"
case .configurationDisabled: return "Configuration disabled"
case .configurationRemoved: return "Configuration removed"
case .superceded: return "Superceded"
case .userLogout: return "User logged out"
case .userSwitch: return "User switched"
case .connectionFailed: return "Connection failed"
case .sleep: return "Sleep"
case .appUpdate: return "App update"
@unknown default: return "Unknown"
}
}
}