diff --git a/swift/apple/FirezoneKit/Sources/FirezoneKit/AuthClient/AuthClient.swift b/swift/apple/FirezoneKit/Sources/FirezoneKit/AuthClient/AuthClient.swift index 63e60412b..caabea83b 100644 --- a/swift/apple/FirezoneKit/Sources/FirezoneKit/AuthClient/AuthClient.swift +++ b/swift/apple/FirezoneKit/Sources/FirezoneKit/AuthClient/AuthClient.swift @@ -10,8 +10,10 @@ import Foundation enum AuthClientError: Error { case invalidCallbackURL(URL?) + case invalidStateReturnedInCallback(expected: String, got: String) case authResponseError(Error) case sessionFailure(Error) + case randomNumberGenerationFailure(errorStatus: Int32) } struct AuthClient: Sendable { @@ -47,10 +49,17 @@ private final class WebAuthenticationSession: NSObject, @MainActor func signIn(_ host: URL) async throws -> AuthResponse { - try await withCheckedThrowingContinuation { continuation in - let callbackURLScheme = "firezone" + let statePassedToPortal = try Self.createRandomHexString(byteCount: 32) + let nonce = try Self.createRandomHexString(byteCount: 32) + let url = + host + .appendingQueryItem(URLQueryItem(name: "state", value: statePassedToPortal)) + .appendingQueryItem(URLQueryItem(name: "nonce", value: nonce)) + .appendingQueryItem(URLQueryItem(name: "as", value: "client")) + return try await withCheckedThrowingContinuation { continuation in + let callbackURLScheme = "firezone-fd002021111" let session = ASWebAuthenticationSession( - url: host.appendingQueryItem(URLQueryItem(name: "client_platform", value: "apple")), + url: url, callbackURLScheme: callbackURLScheme ) { callbackURL, error in if let error { @@ -64,9 +73,26 @@ private final class WebAuthenticationSession: NSObject, } guard - let token = URLComponents(url: callbackURL, resolvingAgainstBaseURL: false)? + let stateInCallback = URLComponents(url: callbackURL, resolvingAgainstBaseURL: false)? .queryItems? - .first(where: { $0.name == "client_auth_token" })? + .first(where: { $0.name == "state" })? + .value + else { + continuation.resume(throwing: AuthClientError.invalidCallbackURL(callbackURL)) + return + } + + guard Self.areStringsEqualConstantTime(statePassedToPortal, stateInCallback) else { + continuation.resume( + throwing: AuthClientError.invalidStateReturnedInCallback( + expected: statePassedToPortal, got: stateInCallback)) + return + } + + guard + let fragment = URLComponents(url: callbackURL, resolvingAgainstBaseURL: false)? + .queryItems? + .first(where: { $0.name == "fragment" })? .value else { continuation.resume(throwing: AuthClientError.invalidCallbackURL(callbackURL)) @@ -85,6 +111,7 @@ private final class WebAuthenticationSession: NSObject, return } + let token = nonce + fragment let authResponse = AuthResponse(portalURL: host, token: token, actorName: actorName) continuation.resume(returning: authResponse) } @@ -100,6 +127,34 @@ private final class WebAuthenticationSession: NSObject, } } + static func createRandomHexString(byteCount: Int) throws -> String { + var bytes = [Int8](repeating: 0, count: byteCount) + let status = SecRandomCopyBytes(kSecRandomDefault, bytes.count, &bytes) + + guard status == errSecSuccess else { + throw AuthClientError.randomNumberGenerationFailure(errorStatus: status) + } + + return bytes.map { String(format: "%02hhx", $0) }.joined() + } + + static func areStringsEqualConstantTime(_ string1: String, _ string2: String) -> Bool { + let charArray1 = string1.utf8CString + let charArray2 = string2.utf8CString + + if charArray1.count != charArray2.count { + return false + } + + var result: CChar = 0 + for (char1, char2) in zip(charArray1, charArray2) { + // Iff all the XORs result in 0, then the strings are equal + result |= (char1 ^ char2) + } + + return (result == 0) + } + func presentationAnchor(for _: ASWebAuthenticationSession) -> ASPresentationAnchor { ASPresentationAnchor() }