diff --git a/Sources/CodexBar/CodexbarApp.swift b/Sources/CodexBar/CodexbarApp.swift index f5097b0e4..febb19ddc 100644 --- a/Sources/CodexBar/CodexbarApp.swift +++ b/Sources/CodexBar/CodexbarApp.swift @@ -81,7 +81,10 @@ struct CodexBarApp: App { store: self.store, updater: self.appDelegate.updaterController, selection: self.preferencesSelection, - managedCodexAccountCoordinator: self.managedCodexAccountCoordinator) + managedCodexAccountCoordinator: self.managedCodexAccountCoordinator, + runProviderLoginFlow: { provider in + await self.appDelegate.runProviderLoginFlow(provider) + }) } .defaultSize(width: PreferencesTab.general.preferredWidth, height: PreferencesTab.general.preferredHeight) .windowResizability(.contentSize) @@ -300,6 +303,12 @@ final class AppDelegate: NSObject, NSApplicationDelegate { TTYCommandRunner.terminateActiveProcessesForAppShutdown() } + func runProviderLoginFlow(_ provider: UsageProvider) async { + self.ensureStatusController() + guard let statusController else { return } + await statusController.runLoginFlowFromSettings(provider: provider) + } + /// Use the classic (non-Liquid Glass) app icon on macOS versions before 26. private func configureAppIconForMacOSVersion() { if #unavailable(macOS 26) { diff --git a/Sources/CodexBar/PreferencesProviderDetailView.swift b/Sources/CodexBar/PreferencesProviderDetailView.swift index edcc1d0dc..498e62eee 100644 --- a/Sources/CodexBar/PreferencesProviderDetailView.swift +++ b/Sources/CodexBar/PreferencesProviderDetailView.swift @@ -11,6 +11,7 @@ struct ProviderDetailView: View { let settingsPickers: [ProviderSettingsPickerDescriptor] let settingsToggles: [ProviderSettingsToggleDescriptor] let settingsFields: [ProviderSettingsFieldDescriptor] + let settingsActions: [ProviderSettingsActionsDescriptor] let settingsTokenAccounts: ProviderSettingsTokenAccountsDescriptor? let errorDisplay: ProviderErrorDisplay? @Binding var isErrorExpanded: Bool @@ -28,6 +29,7 @@ struct ProviderDetailView: View { settingsPickers: [ProviderSettingsPickerDescriptor], settingsToggles: [ProviderSettingsToggleDescriptor], settingsFields: [ProviderSettingsFieldDescriptor], + settingsActions: [ProviderSettingsActionsDescriptor] = [], settingsTokenAccounts: ProviderSettingsTokenAccountsDescriptor?, errorDisplay: ProviderErrorDisplay?, isErrorExpanded: Binding, @@ -44,6 +46,7 @@ struct ProviderDetailView: View { self.settingsPickers = settingsPickers self.settingsToggles = settingsToggles self.settingsFields = settingsFields + self.settingsActions = settingsActions self.settingsTokenAccounts = settingsTokenAccounts self.errorDisplay = errorDisplay self._isErrorExpanded = isErrorExpanded @@ -118,6 +121,9 @@ struct ProviderDetailView: View { ForEach(self.settingsFields) { field in ProviderSettingsFieldRowView(field: field) } + ForEach(self.settingsActions) { descriptor in + ProviderSettingsActionsRowView(descriptor: descriptor) + } } } @@ -143,6 +149,7 @@ struct ProviderDetailView: View { private var hasSettings: Bool { !self.settingsPickers.isEmpty || !self.settingsFields.isEmpty || + !self.settingsActions.isEmpty || self.settingsTokenAccounts != nil } diff --git a/Sources/CodexBar/PreferencesProviderSettingsRows.swift b/Sources/CodexBar/PreferencesProviderSettingsRows.swift index 414f41c55..97d8465bb 100644 --- a/Sources/CodexBar/PreferencesProviderSettingsRows.swift +++ b/Sources/CodexBar/PreferencesProviderSettingsRows.swift @@ -200,6 +200,40 @@ struct ProviderSettingsFieldRowView: View { } } +@MainActor +struct ProviderSettingsActionsRowView: View { + let descriptor: ProviderSettingsActionsDescriptor + + var body: some View { + VStack(alignment: .leading, spacing: 8) { + Text(self.descriptor.title) + .font(.subheadline.weight(.semibold)) + + if !self.descriptor.subtitle.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + Text(self.descriptor.subtitle) + .font(.footnote) + .foregroundStyle(.secondary) + .fixedSize(horizontal: false, vertical: true) + } + + let actions = self.descriptor.actions.filter { $0.isVisible?() ?? true } + if !actions.isEmpty { + HStack(spacing: 10) { + ForEach(actions) { action in + Button(action.title) { + Task { @MainActor in + await action.perform() + } + } + .applyProviderSettingsButtonStyle(action.style) + .controlSize(.small) + } + } + } + } + } +} + @MainActor struct ProviderSettingsTokenAccountsRowView: View { let descriptor: ProviderSettingsTokenAccountsDescriptor diff --git a/Sources/CodexBar/PreferencesProvidersPane.swift b/Sources/CodexBar/PreferencesProvidersPane.swift index 53943cf53..ff326ed65 100644 --- a/Sources/CodexBar/PreferencesProvidersPane.swift +++ b/Sources/CodexBar/PreferencesProvidersPane.swift @@ -8,6 +8,7 @@ struct ProvidersPane: View { @Bindable var store: UsageStore let managedCodexAccountCoordinator: ManagedCodexAccountCoordinator let codexAmbientLoginRunner: any CodexAmbientLoginRunning + let runProviderLoginFlow: @MainActor (UsageProvider) async -> Void @State private var expandedErrors: Set = [] @State private var settingsStatusTextByID: [String: String] = [:] @State private var settingsLastAppActiveRunAtByID: [String: Date] = [:] @@ -24,12 +25,14 @@ struct ProvidersPane: View { settings: SettingsStore, store: UsageStore, managedCodexAccountCoordinator: ManagedCodexAccountCoordinator = ManagedCodexAccountCoordinator(), - codexAmbientLoginRunner: any CodexAmbientLoginRunning = DefaultCodexAmbientLoginRunner()) + codexAmbientLoginRunner: any CodexAmbientLoginRunning = DefaultCodexAmbientLoginRunner(), + runProviderLoginFlow: @escaping @MainActor (UsageProvider) async -> Void = { _ in }) { self.settings = settings self.store = store self.managedCodexAccountCoordinator = managedCodexAccountCoordinator self.codexAmbientLoginRunner = codexAmbientLoginRunner + self.runProviderLoginFlow = runProviderLoginFlow } var body: some View { @@ -54,6 +57,7 @@ struct ProvidersPane: View { settingsPickers: self.extraSettingsPickers(for: provider), settingsToggles: self.extraSettingsToggles(for: provider), settingsFields: self.extraSettingsFields(for: provider), + settingsActions: self.extraSettingsActions(for: provider), settingsTokenAccounts: self.tokenAccountDescriptor(for: provider), errorDisplay: self.providerErrorDisplay(provider), isErrorExpanded: self.expandedBinding(for: provider), @@ -309,6 +313,13 @@ struct ProvidersPane: View { .filter { $0.isVisible?() ?? true } } + private func extraSettingsActions(for provider: UsageProvider) -> [ProviderSettingsActionsDescriptor] { + guard let impl = ProviderCatalog.implementation(for: provider) else { return [] } + let context = self.makeSettingsContext(provider: provider) + return impl.settingsActions(context: context) + .filter { $0.isVisible?() ?? true } + } + func tokenAccountDescriptor(for provider: UsageProvider) -> ProviderSettingsTokenAccountsDescriptor? { guard let support = TokenAccountSupportCatalog.support(for: provider) else { return nil } let context = self.makeSettingsContext(provider: provider) @@ -403,6 +414,9 @@ struct ProvidersPane: View { }, requestConfirmation: { confirmation in self.activeConfirmation = ProviderSettingsConfirmationState(confirmation: confirmation) + }, + runLoginFlow: { + await self.runProviderLoginFlow(provider) }) } diff --git a/Sources/CodexBar/PreferencesView.swift b/Sources/CodexBar/PreferencesView.swift index 3f2e58ec9..52beec9cb 100644 --- a/Sources/CodexBar/PreferencesView.swift +++ b/Sources/CodexBar/PreferencesView.swift @@ -1,4 +1,5 @@ import AppKit +import CodexBarCore import SwiftUI enum PreferencesTab: String, Hashable { @@ -29,6 +30,7 @@ struct PreferencesView: View { let updater: UpdaterProviding @Bindable var selection: PreferencesSelection let managedCodexAccountCoordinator: ManagedCodexAccountCoordinator + let runProviderLoginFlow: @MainActor (UsageProvider) async -> Void @State private var contentWidth: CGFloat = PreferencesTab.general.preferredWidth @State private var contentHeight: CGFloat = PreferencesTab.general.preferredHeight @@ -37,13 +39,15 @@ struct PreferencesView: View { store: UsageStore, updater: UpdaterProviding, selection: PreferencesSelection, - managedCodexAccountCoordinator: ManagedCodexAccountCoordinator = ManagedCodexAccountCoordinator()) + managedCodexAccountCoordinator: ManagedCodexAccountCoordinator = ManagedCodexAccountCoordinator(), + runProviderLoginFlow: @escaping @MainActor (UsageProvider) async -> Void = { _ in }) { self.settings = settings self.store = store self.updater = updater self.selection = selection self.managedCodexAccountCoordinator = managedCodexAccountCoordinator + self.runProviderLoginFlow = runProviderLoginFlow } var body: some View { @@ -55,7 +59,8 @@ struct PreferencesView: View { ProvidersPane( settings: self.settings, store: self.store, - managedCodexAccountCoordinator: self.managedCodexAccountCoordinator) + managedCodexAccountCoordinator: self.managedCodexAccountCoordinator, + runProviderLoginFlow: self.runProviderLoginFlow) .tabItem { Label("Providers", systemImage: "square.grid.2x2") } .tag(PreferencesTab.providers) diff --git a/Sources/CodexBar/Providers/Antigravity/AntigravityLoginFlow.swift b/Sources/CodexBar/Providers/Antigravity/AntigravityLoginFlow.swift index e41aa8fe6..74133bccb 100644 --- a/Sources/CodexBar/Providers/Antigravity/AntigravityLoginFlow.swift +++ b/Sources/CodexBar/Providers/Antigravity/AntigravityLoginFlow.swift @@ -3,9 +3,28 @@ import CodexBarCore @MainActor extension StatusItemController { func runAntigravityLoginFlow() async { + let store = self.store + let phaseHandler: @Sendable (AntigravityLoginRunner.Phase) -> Void = { [weak self] phase in + Task { @MainActor in + switch phase { + case .waitingBrowser: + self?.loginPhase = .waitingBrowser + } + } + } + let result = await AntigravityLoginRunner.run(onPhaseChange: phaseHandler) { + Task { @MainActor in + await store.refresh() + CodexBarLog.logger(LogCategories.login).info("Auto-refreshed after Antigravity auth") + } + } + guard !Task.isCancelled else { return } self.loginPhase = .idle - self.presentLoginAlert( - title: "Antigravity login is managed in the app", - message: "Open Antigravity to sign in, then refresh CodexBar.") + self.presentAntigravityLoginResult(result) + let outcome = self.describe(result.outcome) + self.loginLogger.info("Antigravity login", metadata: ["outcome": outcome]) + if case .success = result.outcome { + self.postLoginNotification(for: .antigravity) + } } } diff --git a/Sources/CodexBar/Providers/Antigravity/AntigravityLoginRunner.swift b/Sources/CodexBar/Providers/Antigravity/AntigravityLoginRunner.swift new file mode 100644 index 000000000..9d5d5e1d7 --- /dev/null +++ b/Sources/CodexBar/Providers/Antigravity/AntigravityLoginRunner.swift @@ -0,0 +1,485 @@ +import AppKit +import CodexBarCore +import Darwin +import Foundation +import Network + +enum AntigravityLoginRunner { + enum Phase { + case waitingBrowser + } + + struct Result { + enum Outcome { + case success(String?) + case cancelled + case timedOut + case launchFailed(String) + case failed(String) + } + + let outcome: Outcome + } + + static func run( + timeout: TimeInterval = 120, + onPhaseChange: (@Sendable (Phase) -> Void)? = nil, + onCredentialsCreated: (@Sendable () -> Void)? = nil) async -> Result + { + guard let oauthClient = AntigravityOAuthConfig.resolvedClient() else { + return Result(outcome: .failed(AntigravityOAuthConfig.missingCredentialsMessage)) + } + + let state = UUID().uuidString.replacingOccurrences(of: "-", with: "") + let server = AntigravityLoopbackServer(state: state) + + do { + let callbackURL = try await server.start() + let authURL = try Self.makeAuthorizationURL( + redirectURL: callbackURL, + state: state, + oauthClient: oauthClient) + onPhaseChange?(.waitingBrowser) + + let opened = await MainActor.run { + NSWorkspace.shared.open(authURL) + } + guard opened else { + server.stop() + return Result(outcome: .launchFailed(authURL.absoluteString)) + } + + let callback = try await withThrowingTaskGroup(of: AntigravityOAuthCallback.self) { group in + group.addTask { + try await server.waitForCallback() + } + group.addTask { + try await Task.sleep(for: .seconds(timeout)) + server.cancelCallbackWait(with: AntigravityLoginError.timedOut) + throw AntigravityLoginError.timedOut + } + defer { group.cancelAll() } + return try await group.next().unsafelyUnwrapped + } + server.stop() + + if let error = callback.error?.trimmingCharacters(in: .whitespacesAndNewlines), !error.isEmpty { + if error == "access_denied" { + return Result(outcome: .cancelled) + } + return Result(outcome: .failed(error)) + } + + guard callback.returnedState == state else { + return Result(outcome: .failed("Google login state mismatch.")) + } + guard let code = callback.code?.trimmingCharacters(in: .whitespacesAndNewlines), !code.isEmpty else { + return Result(outcome: .failed("Google login did not return an authorization code.")) + } + + let tokenResponse = try await Self.exchangeCodeForTokens( + code: code, + redirectURL: callbackURL, + oauthClient: oauthClient) + let email = try await Self.fetchUserEmail(accessToken: tokenResponse.accessToken) + let credentials = AntigravityOAuthCredentials( + accessToken: tokenResponse.accessToken, + refreshToken: tokenResponse.refreshToken, + expiryDate: Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn)), + idToken: tokenResponse.idToken, + email: email, + projectID: nil, + clientID: oauthClient.clientID, + clientSecret: oauthClient.clientSecret) + try AntigravityOAuthCredentialsStore().save(credentials) + onCredentialsCreated?() + return Result(outcome: .success(email)) + } catch is CancellationError { + server.stop() + return Result(outcome: .cancelled) + } catch AntigravityLoginError.timedOut { + server.stop() + return Result(outcome: .timedOut) + } catch AntigravityLoginError.launchFailed(let message) { + server.stop() + return Result(outcome: .launchFailed(message)) + } catch { + server.stop() + return Result(outcome: .failed(error.localizedDescription)) + } + } + + private static func makeAuthorizationURL( + redirectURL: URL, + state: String, + oauthClient: AntigravityOAuthClient) throws -> URL + { + guard var components = URLComponents(url: AntigravityOAuthConfig.authURL, resolvingAgainstBaseURL: false) else { + throw AntigravityLoginError.invalidAuthorizationURL + } + components.queryItems = [ + URLQueryItem(name: "client_id", value: oauthClient.clientID), + URLQueryItem(name: "redirect_uri", value: redirectURL.absoluteString), + URLQueryItem(name: "response_type", value: "code"), + URLQueryItem(name: "scope", value: AntigravityOAuthConfig.scopes.joined(separator: " ")), + URLQueryItem(name: "access_type", value: "offline"), + URLQueryItem(name: "prompt", value: "consent"), + URLQueryItem(name: "state", value: state), + ] + guard let url = components.url else { + throw AntigravityLoginError.invalidAuthorizationURL + } + return url + } + + private static func exchangeCodeForTokens( + code: String, + redirectURL: URL, + oauthClient: AntigravityOAuthClient) async throws -> TokenResponse + { + var request = URLRequest(url: AntigravityOAuthConfig.tokenURL) + request.httpMethod = "POST" + request.timeoutInterval = 30 + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + request.httpBody = Self.formBody([ + "code": code, + "client_id": oauthClient.clientID, + "client_secret": oauthClient.clientSecret, + "redirect_uri": redirectURL.absoluteString, + "grant_type": "authorization_code", + ]) + + let (data, response) = try await URLSession.shared.data(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw AntigravityLoginError.failed("Invalid token response.") + } + guard httpResponse.statusCode == 200 else { + let message = String(data: data, encoding: .utf8)?.trimmingCharacters(in: .whitespacesAndNewlines) + ?? "HTTP \(httpResponse.statusCode)" + throw AntigravityLoginError.failed(message) + } + do { + return try JSONDecoder().decode(TokenResponse.self, from: data) + } catch { + throw AntigravityLoginError.failed("Could not decode token response.") + } + } + + private static func fetchUserEmail(accessToken: String) async throws -> String? { + var request = URLRequest(url: AntigravityOAuthConfig.userInfoURL) + request.timeoutInterval = 15 + request.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization") + + do { + let (data, response) = try await URLSession.shared.data(for: request) + guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else { + return nil + } + let userInfo = try JSONDecoder().decode(UserInfoResponse.self, from: data) + let email = userInfo.email?.trimmingCharacters(in: .whitespacesAndNewlines) + return (email?.isEmpty == false) ? email : nil + } catch { + return nil + } + } + + private static func formBody(_ values: [String: String]) -> Data? { + values + .map { key, value in + let encodedKey = key.addingPercentEncoding(withAllowedCharacters: .urlQueryValueAllowed) ?? key + let encodedValue = value.addingPercentEncoding(withAllowedCharacters: .urlQueryValueAllowed) ?? value + return "\(encodedKey)=\(encodedValue)" + } + .joined(separator: "&") + .data(using: .utf8) + } + +} + +private enum AntigravityLoginError: LocalizedError { + case invalidAuthorizationURL + case timedOut + case launchFailed(String) + case failed(String) + + var errorDescription: String? { + switch self { + case .invalidAuthorizationURL: + "Could not build the Antigravity login URL." + case .timedOut: + "Antigravity login timed out." + case let .launchFailed(message): + message + case let .failed(message): + message + } + } +} + +private struct TokenResponse: Decodable { + let accessToken: String + let refreshToken: String? + let expiresIn: Int + let idToken: String? + + enum CodingKeys: String, CodingKey { + case accessToken = "access_token" + case refreshToken = "refresh_token" + case expiresIn = "expires_in" + case idToken = "id_token" + } +} + +private struct UserInfoResponse: Decodable { + let email: String? +} + +private struct AntigravityOAuthCallback: Sendable { + let code: String? + let returnedState: String? + let error: String? +} + +private final class AntigravityLoopbackServer: @unchecked Sendable { + private let expectedState: String + private let queue = DispatchQueue(label: "codexbar.antigravity.oauth") + private let lock = NSLock() + private var listener: NWListener? + private var readyContinuation: CheckedContinuation? + private var callbackContinuation: CheckedContinuation? + private var pendingCallbackResult: Result? + private var completed = false + + init(state: String) { + self.expectedState = state + } + + func start() async throws -> URL { + let port = try Self.findAvailablePort() + guard let endpointPort = NWEndpoint.Port(rawValue: port) else { + throw AntigravityLoginError.failed("Could not reserve a local callback port.") + } + let listener = try NWListener(using: .tcp, on: endpointPort) + self.listener = listener + listener.newConnectionHandler = { [weak self] connection in + self?.handle(connection) + } + + return try await withCheckedThrowingContinuation { continuation in + self.readyContinuation = continuation + listener.stateUpdateHandler = { [weak self] state in + guard let self else { return } + switch state { + case .ready: + let url = URL(string: "http://127.0.0.1:\(port)/callback")! + self.finishReady(with: .success(url)) + case .failed(let error): + self.finishReady(with: .failure(error)) + self.finishCallback(with: .failure(error)) + default: + break + } + } + listener.start(queue: self.queue) + } + } + + func waitForCallback() async throws -> AntigravityOAuthCallback { + try await withCheckedThrowingContinuation { continuation in + self.lock.lock() + defer { self.lock.unlock() } + if let pending = self.pendingCallbackResult { + self.pendingCallbackResult = nil + switch pending { + case .success(let callback): + continuation.resume(returning: callback) + case .failure(let error): + continuation.resume(throwing: error) + } + return + } + self.callbackContinuation = continuation + } + } + + func stop() { + self.listener?.cancel() + self.listener = nil + } + + func cancelCallbackWait(with error: Error) { + self.stop() + self.finishCallback(with: .failure(error)) + } + + private func handle(_ connection: NWConnection) { + connection.start(queue: self.queue) + self.receive(on: connection, accumulated: Data()) + } + + private func receive(on connection: NWConnection, accumulated: Data) { + connection.receive(minimumIncompleteLength: 1, maximumLength: 65_536) { [weak self] data, _, isComplete, error in + guard let self else { return } + if let error { + self.finishCallback(with: .failure(error)) + connection.cancel() + return + } + + var buffer = accumulated + if let data { + buffer.append(data) + } + + let headerMarker = Data("\r\n\r\n".utf8) + if buffer.range(of: headerMarker) == nil, !isComplete { + self.receive(on: connection, accumulated: buffer) + return + } + + let callback = self.parseCallback(from: buffer) + let response = self.httpResponse(for: callback) + connection.send(content: response, completion: .contentProcessed { _ in + connection.cancel() + }) + self.finishCallback(with: .success(callback)) + } + } + + private func parseCallback(from data: Data) -> AntigravityOAuthCallback { + guard let request = String(data: data, encoding: .utf8), + let line = request.components(separatedBy: "\r\n").first + else { + return AntigravityOAuthCallback(code: nil, returnedState: nil, error: "Invalid callback request.") + } + + let parts = line.split(separator: " ") + guard parts.count >= 2, + let url = URL(string: "http://127.0.0.1\(parts[1])"), + let components = URLComponents(url: url, resolvingAgainstBaseURL: false) + else { + return AntigravityOAuthCallback(code: nil, returnedState: nil, error: "Invalid callback URL.") + } + + let code = components.queryItems?.first(where: { $0.name == "code" })?.value + let returnedState = components.queryItems?.first(where: { $0.name == "state" })?.value + let error = components.queryItems?.first(where: { $0.name == "error" })?.value + + guard components.path == "/callback" else { + return AntigravityOAuthCallback(code: nil, returnedState: returnedState, error: "Unexpected callback path.") + } + if let returnedState, returnedState != self.expectedState { + return AntigravityOAuthCallback(code: code, returnedState: returnedState, error: "State mismatch.") + } + return AntigravityOAuthCallback(code: code, returnedState: returnedState, error: error) + } + + private func httpResponse(for callback: AntigravityOAuthCallback) -> Data { + let success = callback.error == nil && callback.code?.isEmpty == false + let status = success ? "200 OK" : "400 Bad Request" + let title = success ? "Login Successful" : "Login Failed" + let detail = success + ? "You can close this window and return to CodexBar." + : "You can close this window and try again." + let html = """ + + +

\(title)

+

\(detail)

+ + + """ + let body = Data(html.utf8) + let header = """ + HTTP/1.1 \(status)\r + Content-Type: text/html; charset=utf-8\r + Content-Length: \(body.count)\r + Connection: close\r + \r + """ + var response = Data(header.utf8) + response.append(body) + return response + } + + private func finishReady(with result: Result) { + self.lock.lock() + let continuation = self.readyContinuation + self.readyContinuation = nil + self.lock.unlock() + switch result { + case .success(let url): + continuation?.resume(returning: url) + case .failure(let error): + continuation?.resume(throwing: error) + } + } + + private func finishCallback(with result: Result) { + self.lock.lock() + guard !self.completed else { + self.lock.unlock() + return + } + self.completed = true + let continuation = self.callbackContinuation + self.callbackContinuation = nil + if continuation == nil { + self.pendingCallbackResult = result + } + self.lock.unlock() + guard let continuation else { return } + switch result { + case .success(let callback): + continuation.resume(returning: callback) + case .failure(let error): + continuation.resume(throwing: error) + } + } + + private static func findAvailablePort() throws -> UInt16 { + let socketFD = socket(AF_INET, Int32(SOCK_STREAM), 0) + guard socketFD >= 0 else { + throw AntigravityLoginError.failed("Could not create a local callback socket.") + } + defer { close(socketFD) } + + var value: Int32 = 1 + setsockopt(socketFD, SOL_SOCKET, SO_REUSEADDR, &value, socklen_t(MemoryLayout.size)) + + var address = sockaddr_in() + address.sin_len = UInt8(MemoryLayout.stride) + address.sin_family = sa_family_t(AF_INET) + address.sin_port = in_port_t(0).bigEndian + address.sin_addr = in_addr(s_addr: inet_addr("127.0.0.1")) + + let bindResult = withUnsafePointer(to: &address) { pointer in + pointer.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPointer in + bind(socketFD, sockaddrPointer, socklen_t(MemoryLayout.stride)) + } + } + guard bindResult == 0 else { + throw AntigravityLoginError.failed("Could not bind a local callback port.") + } + + var boundAddress = sockaddr_in() + var length = socklen_t(MemoryLayout.stride) + let nameResult = withUnsafeMutablePointer(to: &boundAddress) { pointer in + pointer.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPointer in + getsockname(socketFD, sockaddrPointer, &length) + } + } + guard nameResult == 0 else { + throw AntigravityLoginError.failed("Could not inspect the callback port.") + } + return UInt16(bigEndian: boundAddress.sin_port) + } +} + +private extension CharacterSet { + static let urlQueryValueAllowed: CharacterSet = { + var allowed = CharacterSet.urlQueryAllowed + allowed.remove(charactersIn: "+&=") + return allowed + }() +} diff --git a/Sources/CodexBar/Providers/Antigravity/AntigravityProviderImplementation.swift b/Sources/CodexBar/Providers/Antigravity/AntigravityProviderImplementation.swift index fcac10ed1..c3b1ffc7a 100644 --- a/Sources/CodexBar/Providers/Antigravity/AntigravityProviderImplementation.swift +++ b/Sources/CodexBar/Providers/Antigravity/AntigravityProviderImplementation.swift @@ -1,10 +1,81 @@ import CodexBarCore import CodexBarMacroSupport import Foundation +import SwiftUI @ProviderImplementationRegistration struct AntigravityProviderImplementation: ProviderImplementation { let id: UsageProvider = .antigravity + let supportsLoginFlow: Bool = true + + @MainActor + func observeSettings(_ settings: SettingsStore) { + _ = settings.antigravityUsageDataSource + } + + @MainActor + func defaultSourceLabel(context: ProviderSourceLabelContext) -> String? { + context.settings.antigravityUsageDataSource.rawValue + } + + @MainActor + func sourceMode(context: ProviderSourceModeContext) -> ProviderSourceMode { + switch context.settings.antigravityUsageDataSource { + case .auto: .auto + case .oauth: .oauth + case .cli: .cli + } + } + + @MainActor + func settingsPickers(context: ProviderSettingsContext) -> [ProviderSettingsPickerDescriptor] { + let usageBinding = Binding( + get: { context.settings.antigravityUsageDataSource.rawValue }, + set: { raw in + context.settings.antigravityUsageDataSource = AntigravityUsageDataSource(rawValue: raw) ?? .auto + }) + let usageOptions = AntigravityUsageDataSource.allCases.map { + ProviderSettingsPickerOption(id: $0.rawValue, title: $0.displayName) + } + return [ + ProviderSettingsPickerDescriptor( + id: "antigravity-usage-source", + title: "Usage source", + subtitle: "Auto uses the local IDE API first, then Google OAuth when the IDE is closed.", + binding: usageBinding, + options: usageOptions, + isVisible: nil, + onChange: nil, + trailingText: { + guard context.settings.antigravityUsageDataSource == .auto else { return nil } + let label = context.store.sourceLabel(for: .antigravity) + return label == "auto" ? nil : label + }), + ] + } + + @MainActor + func settingsActions(context: ProviderSettingsContext) -> [ProviderSettingsActionsDescriptor] { + let credentialsPath = AntigravityOAuthCredentialsStore().fileURL.path + let loginTitle = FileManager.default.fileExists(atPath: credentialsPath) ? "Re-authenticate" : "Login with Google" + return [ + ProviderSettingsActionsDescriptor( + id: "antigravity-oauth", + title: "Google OAuth", + subtitle: "Stores credentials in ~/.codexbar/antigravity/oauth_creds.json. Uses Antigravity.app OAuth when available, or ANTIGRAVITY_OAUTH_CLIENT_ID and ANTIGRAVITY_OAUTH_CLIENT_SECRET as an override.", + actions: [ + ProviderSettingsActionDescriptor( + id: "antigravity-oauth-login", + title: loginTitle, + style: .bordered, + isVisible: nil, + perform: { + await context.runLoginFlow() + }), + ], + isVisible: nil), + ] + } func detectVersion(context _: ProviderVersionContext) async -> String? { await AntigravityStatusProbe.detectVersion() diff --git a/Sources/CodexBar/Providers/Antigravity/AntigravitySettingsStore.swift b/Sources/CodexBar/Providers/Antigravity/AntigravitySettingsStore.swift new file mode 100644 index 000000000..a8deaeaa9 --- /dev/null +++ b/Sources/CodexBar/Providers/Antigravity/AntigravitySettingsStore.swift @@ -0,0 +1,36 @@ +import CodexBarCore +import Foundation + +extension SettingsStore { + var antigravityUsageDataSource: AntigravityUsageDataSource { + get { + let source = self.configSnapshot.providerConfig(for: .antigravity)?.source + return Self.antigravityUsageDataSource(from: source) + } + set { + let source: ProviderSourceMode? = switch newValue { + case .auto: .auto + case .oauth: .oauth + case .cli: .cli + } + self.updateProviderConfig(provider: .antigravity) { entry in + entry.source = source + } + self.logProviderModeChange(provider: .antigravity, field: "usageSource", value: newValue.rawValue) + } + } +} + +extension SettingsStore { + private static func antigravityUsageDataSource(from source: ProviderSourceMode?) -> AntigravityUsageDataSource { + guard let source else { return .auto } + switch source { + case .auto, .web, .api: + return .auto + case .oauth: + return .oauth + case .cli: + return .cli + } + } +} diff --git a/Sources/CodexBar/Providers/Shared/ProviderImplementation.swift b/Sources/CodexBar/Providers/Shared/ProviderImplementation.swift index 7d5e22bd2..7340de2af 100644 --- a/Sources/CodexBar/Providers/Shared/ProviderImplementation.swift +++ b/Sources/CodexBar/Providers/Shared/ProviderImplementation.swift @@ -42,6 +42,10 @@ protocol ProviderImplementation: Sendable { @MainActor func settingsFields(context: ProviderSettingsContext) -> [ProviderSettingsFieldDescriptor] + /// Optional provider-specific settings action rows to render in the Providers pane. + @MainActor + func settingsActions(context: ProviderSettingsContext) -> [ProviderSettingsActionsDescriptor] + /// Optional provider-specific settings pickers to render in the Providers pane. @MainActor func settingsPickers(context: ProviderSettingsContext) -> [ProviderSettingsPickerDescriptor] @@ -129,6 +133,11 @@ extension ProviderImplementation { [] } + @MainActor + func settingsActions(context _: ProviderSettingsContext) -> [ProviderSettingsActionsDescriptor] { + [] + } + @MainActor func settingsPickers(context _: ProviderSettingsContext) -> [ProviderSettingsPickerDescriptor] { [] diff --git a/Sources/CodexBar/Providers/Shared/ProviderSettingsDescriptors.swift b/Sources/CodexBar/Providers/Shared/ProviderSettingsDescriptors.swift index d5a85b8f7..1acf75e2d 100644 --- a/Sources/CodexBar/Providers/Shared/ProviderSettingsDescriptors.swift +++ b/Sources/CodexBar/Providers/Shared/ProviderSettingsDescriptors.swift @@ -25,6 +25,7 @@ struct ProviderSettingsContext { let setLastAppActiveRunAt: (String, Date?) -> Void let requestConfirmation: (ProviderSettingsConfirmation) -> Void + let runLoginFlow: () async -> Void } /// Shared confirmation alert descriptor. @@ -84,6 +85,16 @@ struct ProviderSettingsFieldDescriptor: Identifiable { let onActivate: (() -> Void)? } +/// Shared action row descriptor rendered in the Providers settings pane. +@MainActor +struct ProviderSettingsActionsDescriptor: Identifiable { + let id: String + let title: String + let subtitle: String + let actions: [ProviderSettingsActionDescriptor] + let isVisible: (() -> Bool)? +} + /// Shared token account descriptor rendered in the Providers settings pane. @MainActor struct ProviderSettingsTokenAccountsDescriptor: Identifiable { diff --git a/Sources/CodexBar/SettingsStore+ProviderDetection.swift b/Sources/CodexBar/SettingsStore+ProviderDetection.swift index b8534276c..d4efded24 100644 --- a/Sources/CodexBar/SettingsStore+ProviderDetection.swift +++ b/Sources/CodexBar/SettingsStore+ProviderDetection.swift @@ -17,14 +17,17 @@ extension SettingsStore { let claudeInstalled = BinaryLocator.resolveClaudeBinary() != nil let geminiInstalled = BinaryLocator.resolveGeminiBinary() != nil let antigravityRunning = await AntigravityStatusProbe.isRunning() + let antigravityLoggedIn = FileManager.default.fileExists( + atPath: AntigravityOAuthCredentialsStore().fileURL.path) let logger = CodexBarLog.logger(LogCategories.providerDetection) // If none installed, keep Codex enabled to match previous behavior. - let noneInstalled = !codexInstalled && !claudeInstalled && !geminiInstalled && !antigravityRunning + let noneInstalled = !codexInstalled && !claudeInstalled && !geminiInstalled && !antigravityRunning && + !antigravityLoggedIn let enableCodex = codexInstalled || noneInstalled let enableClaude = claudeInstalled let enableGemini = geminiInstalled - let enableAntigravity = antigravityRunning + let enableAntigravity = antigravityRunning || antigravityLoggedIn logger.info( "Provider detection results", @@ -33,6 +36,7 @@ extension SettingsStore { "claudeInstalled": claudeInstalled ? "1" : "0", "geminiInstalled": geminiInstalled ? "1" : "0", "antigravityRunning": antigravityRunning ? "1" : "0", + "antigravityLoggedIn": antigravityLoggedIn ? "1" : "0", ]) logger.info( "Provider detection enablement", diff --git a/Sources/CodexBar/StatusItemController+Actions.swift b/Sources/CodexBar/StatusItemController+Actions.swift index 930e42c00..688d2e69f 100644 --- a/Sources/CodexBar/StatusItemController+Actions.swift +++ b/Sources/CodexBar/StatusItemController+Actions.swift @@ -141,25 +141,16 @@ extension StatusItemController { let rawProvider = sender.representedObject as? String let provider = rawProvider.flatMap(UsageProvider.init(rawValue:)) ?? self.lastMenuProvider ?? .codex self.loginLogger.info("Switch Account tapped", metadata: ["provider": provider.rawValue]) + self.startLoginFlow(provider: provider) + } - self.loginTask = Task { @MainActor [weak self] in - guard let self else { return } - defer { - self.activeLoginProvider = nil - self.loginTask = nil - } - self.activeLoginProvider = provider - self.loginPhase = .requesting - self.loginLogger.info("Starting login task", metadata: ["provider": provider.rawValue]) - - let shouldRefresh = await self.runLoginFlow(provider: provider) - if shouldRefresh { - await ProviderInteractionContext.$current.withValue(.userInitiated) { - await self.store.refresh() - } - self.loginLogger.info("Triggered refresh after login", metadata: ["provider": provider.rawValue]) - } + func runLoginFlowFromSettings(provider: UsageProvider) async { + guard self.loginTask == nil else { + self.loginLogger.info("Settings login tap ignored: login already in-flight", metadata: ["provider": provider.rawValue]) + return } + self.startLoginFlow(provider: provider) + await self.loginTask?.value } @objc func showSettingsGeneral() { @@ -236,6 +227,27 @@ extension StatusItemController { return .codex } + private func startLoginFlow(provider: UsageProvider) { + self.loginTask = Task { @MainActor [weak self] in + guard let self else { return } + defer { + self.activeLoginProvider = nil + self.loginTask = nil + } + self.activeLoginProvider = provider + self.loginPhase = .requesting + self.loginLogger.info("Starting login task", metadata: ["provider": provider.rawValue]) + + let shouldRefresh = await self.runLoginFlow(provider: provider) + if shouldRefresh { + await ProviderInteractionContext.$current.withValue(.userInitiated) { + await self.store.refresh() + } + self.loginLogger.info("Triggered refresh after login", metadata: ["provider": provider.rawValue]) + } + } + } + func presentCodexLoginResult(_ result: CodexLoginRunner.Result) { guard let info = CodexLoginAlertPresentation.alertInfo(for: result) else { return } self.presentLoginAlert(title: info.title, message: info.message) @@ -316,11 +328,31 @@ extension StatusItemController { } } + func describe(_ outcome: AntigravityLoginRunner.Result.Outcome) -> String { + switch outcome { + case let .success(email): + "success(email: \(email ?? "nil"))" + case .cancelled: + "cancelled" + case .timedOut: + "timedOut" + case let .launchFailed(message): + "launchFailed(\(message))" + case let .failed(message): + "failed(\(message))" + } + } + func presentGeminiLoginResult(_ result: GeminiLoginRunner.Result) { guard let info = Self.geminiLoginAlertInfo(for: result) else { return } self.presentLoginAlert(title: info.title, message: info.message) } + func presentAntigravityLoginResult(_ result: AntigravityLoginRunner.Result) { + guard let info = Self.antigravityLoginAlertInfo(for: result) else { return } + self.presentLoginAlert(title: info.title, message: info.message) + } + struct LoginAlertInfo: Equatable { let title: String let message: String @@ -339,6 +371,23 @@ extension StatusItemController { } } + nonisolated static func antigravityLoginAlertInfo(for result: AntigravityLoginRunner.Result) -> LoginAlertInfo? { + switch result.outcome { + case .success, .cancelled: + nil + case .timedOut: + LoginAlertInfo( + title: "Antigravity login timed out", + message: "The browser login did not complete in time. Try Antigravity login again.") + case let .launchFailed(message): + LoginAlertInfo( + title: "Could not open browser for Antigravity", + message: "Open this URL manually to continue login:\n\n\(message)") + case let .failed(message): + LoginAlertInfo(title: "Antigravity login failed", message: message) + } + } + func presentLoginAlert(title: String, message: String) { let alert = NSAlert() alert.messageText = title diff --git a/Sources/CodexBar/StatusItemController.swift b/Sources/CodexBar/StatusItemController.swift index 26330607f..fd8cca0d6 100644 --- a/Sources/CodexBar/StatusItemController.swift +++ b/Sources/CodexBar/StatusItemController.swift @@ -9,6 +9,7 @@ import SwiftUI @MainActor protocol StatusItemControlling: AnyObject { func openMenuFromShortcut() + func runLoginFlowFromSettings(provider: UsageProvider) async } @MainActor diff --git a/Sources/CodexBarCore/Providers/Antigravity/AntigravityOAuthCredentialsStore.swift b/Sources/CodexBarCore/Providers/Antigravity/AntigravityOAuthCredentialsStore.swift new file mode 100644 index 000000000..e4d7addd3 --- /dev/null +++ b/Sources/CodexBarCore/Providers/Antigravity/AntigravityOAuthCredentialsStore.swift @@ -0,0 +1,267 @@ +import Foundation + +public struct AntigravityOAuthCredentials: Codable, Sendable, Equatable { + public var accessToken: String? + public var refreshToken: String? + public var expiryDateMilliseconds: Double? + public var idToken: String? + public var email: String? + public var projectID: String? + public var clientID: String? + public var clientSecret: String? + + public init( + accessToken: String?, + refreshToken: String?, + expiryDate: Date?, + idToken: String? = nil, + email: String? = nil, + projectID: String? = nil, + clientID: String? = nil, + clientSecret: String? = nil) + { + self.accessToken = accessToken + self.refreshToken = refreshToken + self.expiryDateMilliseconds = expiryDate.map { $0.timeIntervalSince1970 * 1000 } + self.idToken = idToken + self.email = email + self.projectID = projectID + self.clientID = clientID + self.clientSecret = clientSecret + } + + public var expiryDate: Date? { + guard let expiryDateMilliseconds else { return nil } + return Date(timeIntervalSince1970: expiryDateMilliseconds / 1000) + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.accessToken = + try container.decodeIfPresent(String.self, forKey: .accessTokenSnake) + ?? container.decodeIfPresent(String.self, forKey: .accessTokenCamel) + self.refreshToken = + try container.decodeIfPresent(String.self, forKey: .refreshTokenSnake) + ?? container.decodeIfPresent(String.self, forKey: .refreshTokenCamel) + self.idToken = + try container.decodeIfPresent(String.self, forKey: .idTokenSnake) + ?? container.decodeIfPresent(String.self, forKey: .idTokenCamel) + self.email = try container.decodeIfPresent(String.self, forKey: .email) + self.projectID = + try container.decodeIfPresent(String.self, forKey: .projectIDSnake) + ?? container.decodeIfPresent(String.self, forKey: .projectIDCamel) + self.clientID = + try container.decodeIfPresent(String.self, forKey: .clientIDSnake) + ?? container.decodeIfPresent(String.self, forKey: .clientIDCamel) + self.clientSecret = + try container.decodeIfPresent(String.self, forKey: .clientSecretSnake) + ?? container.decodeIfPresent(String.self, forKey: .clientSecretCamel) + + if let expiryDateMilliseconds = try container.decodeIfPresent(Double.self, forKey: .expiryDateSnake) + ?? container.decodeIfPresent(Double.self, forKey: .expiresAtCamel) + { + self.expiryDateMilliseconds = expiryDateMilliseconds + } else if let expiryDateMilliseconds = try container.decodeIfPresent(Int.self, forKey: .expiryDateSnake) + ?? container.decodeIfPresent(Int.self, forKey: .expiresAtCamel) + { + self.expiryDateMilliseconds = Double(expiryDateMilliseconds) + } else { + self.expiryDateMilliseconds = nil + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encodeIfPresent(self.accessToken, forKey: .accessTokenSnake) + try container.encodeIfPresent(self.refreshToken, forKey: .refreshTokenSnake) + try container.encodeIfPresent(self.expiryDateMilliseconds, forKey: .expiryDateSnake) + try container.encodeIfPresent(self.idToken, forKey: .idTokenSnake) + try container.encodeIfPresent(self.email, forKey: .email) + try container.encodeIfPresent(self.projectID, forKey: .projectIDSnake) + try container.encodeIfPresent(self.clientID, forKey: .clientIDSnake) + try container.encodeIfPresent(self.clientSecret, forKey: .clientSecretSnake) + } + + enum CodingKeys: String, CodingKey { + case accessTokenSnake = "access_token" + case accessTokenCamel = "accessToken" + case refreshTokenSnake = "refresh_token" + case refreshTokenCamel = "refreshToken" + case expiryDateSnake = "expiry_date" + case expiresAtCamel = "expiresAt" + case idTokenSnake = "id_token" + case idTokenCamel = "idToken" + case email + case projectIDSnake = "project_id" + case projectIDCamel = "projectId" + case clientIDSnake = "client_id" + case clientIDCamel = "clientId" + case clientSecretSnake = "client_secret" + case clientSecretCamel = "clientSecret" + } +} + +public struct AntigravityOAuthClient: Sendable, Equatable { + public let clientID: String + public let clientSecret: String + + public init(clientID: String, clientSecret: String) { + self.clientID = clientID + self.clientSecret = clientSecret + } +} + +public enum AntigravityOAuthConfig { + public static var configuredClientID: String? { + let value = ProcessInfo.processInfo.environment["ANTIGRAVITY_OAUTH_CLIENT_ID"] + return value?.trimmingCharacters(in: .whitespacesAndNewlines).nilIfEmpty + } + + public static var configuredClientSecret: String? { + let value = ProcessInfo.processInfo.environment["ANTIGRAVITY_OAUTH_CLIENT_SECRET"] + return value?.trimmingCharacters(in: .whitespacesAndNewlines).nilIfEmpty + } + + public static let authURL = URL(string: "https://accounts.google.com/o/oauth2/v2/auth")! + public static let tokenURL = URL(string: "https://oauth2.googleapis.com/token")! + public static let userInfoURL = URL(string: "https://www.googleapis.com/oauth2/v2/userinfo")! + public static let scopes = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + ] + + public static let missingCredentialsMessage = + "Antigravity OAuth client is not configured. Install Antigravity.app or set ANTIGRAVITY_OAUTH_CLIENT_ID and ANTIGRAVITY_OAUTH_CLIENT_SECRET before logging in." + + public static func resolvedClient() -> AntigravityOAuthClient? { + if let client = Self.environmentClient() { + return client + } + return Self.discoverClientFromInstalledApp() + } + + private static func environmentClient() -> AntigravityOAuthClient? { + guard let clientID = Self.configuredClientID, + let clientSecret = Self.configuredClientSecret + else { + return nil + } + return AntigravityOAuthClient(clientID: clientID, clientSecret: clientSecret) + } + + private static func discoverClientFromInstalledApp(fileManager: FileManager = .default) -> AntigravityOAuthClient? { + for url in Self.candidateAppMainJSURLs(fileManager: fileManager) where fileManager.fileExists(atPath: url.path) { + guard let content = try? String(contentsOf: url, encoding: .utf8), + let client = Self.parseClient(fromMainJS: content) + else { + continue + } + return client + } + return nil + } + + private static func candidateAppMainJSURLs(fileManager: FileManager) -> [URL] { + let bundleRelativePath = "Antigravity.app/Contents/Resources/app/out/main.js" + return [ + URL(fileURLWithPath: "/Applications", isDirectory: true).appendingPathComponent(bundleRelativePath), + fileManager.homeDirectoryForCurrentUser + .appendingPathComponent("Applications", isDirectory: true) + .appendingPathComponent(bundleRelativePath), + ] + } + + private static func parseClient(fromMainJS content: String) -> AntigravityOAuthClient? { + let marker = "vs/platform/cloudCode/common/oauthClient.js" + let searchStart = content.range(of: marker)?.lowerBound ?? content.startIndex + let searchEnd = content.index(searchStart, offsetBy: 4_000, limitedBy: content.endIndex) ?? content.endIndex + let haystack = String(content[searchStart.. String? { + guard let regex = try? NSRegularExpression(pattern: pattern) else { return nil } + let range = NSRange(text.startIndex.. AntigravityOAuthCredentials? { + guard self.fileManager.fileExists(atPath: self.fileURL.path) else { return nil } + let data = try Data(contentsOf: self.fileURL) + return try JSONDecoder().decode(AntigravityOAuthCredentials.self, from: data) + } + + public func save(_ credentials: AntigravityOAuthCredentials) throws { + let data = try JSONEncoder.antigravityCredentials.encode(credentials) + let directory = self.fileURL.deletingLastPathComponent() + if !self.fileManager.fileExists(atPath: directory.path) { + try self.fileManager.createDirectory(at: directory, withIntermediateDirectories: true) + } + try data.write(to: self.fileURL, options: [.atomic]) + try self.applySecurePermissionsIfNeeded() + } + + public func deleteIfPresent() throws { + guard self.fileManager.fileExists(atPath: self.fileURL.path) else { return } + try self.fileManager.removeItem(at: self.fileURL) + } + + public static func defaultDirectoryURL(home: URL = FileManager.default.homeDirectoryForCurrentUser) -> URL { + home + .appendingPathComponent(".codexbar", isDirectory: true) + .appendingPathComponent("antigravity", isDirectory: true) + } + + public static func defaultURL(home: URL = FileManager.default.homeDirectoryForCurrentUser) -> URL { + Self.defaultDirectoryURL(home: home) + .appendingPathComponent("oauth_creds.json") + } + + private func applySecurePermissionsIfNeeded() throws { + #if os(macOS) || os(Linux) + try self.fileManager.setAttributes([ + .posixPermissions: NSNumber(value: Int16(0o600)), + ], ofItemAtPath: self.fileURL.path) + #endif + } +} + +private extension JSONEncoder { + static let antigravityCredentials: JSONEncoder = { + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + return encoder + }() +} + +private extension String { + var nilIfEmpty: String? { + self.isEmpty ? nil : self + } +} diff --git a/Sources/CodexBarCore/Providers/Antigravity/AntigravityProviderDescriptor.swift b/Sources/CodexBarCore/Providers/Antigravity/AntigravityProviderDescriptor.swift index 1e59964b0..796ead0cd 100644 --- a/Sources/CodexBarCore/Providers/Antigravity/AntigravityProviderDescriptor.swift +++ b/Sources/CodexBarCore/Providers/Antigravity/AntigravityProviderDescriptor.swift @@ -33,12 +33,28 @@ public enum AntigravityProviderDescriptor { supportsTokenCost: false, noDataMessage: { "Antigravity cost summary is not supported." }), fetchPlan: ProviderFetchPlan( - sourceModes: [.auto, .cli], - pipeline: ProviderFetchPipeline(resolveStrategies: { _ in [AntigravityStatusFetchStrategy()] })), + sourceModes: [.auto, .cli, .oauth], + pipeline: ProviderFetchPipeline(resolveStrategies: self.resolveStrategies)), cli: ProviderCLIConfig( name: "antigravity", versionDetector: nil)) } + + private static func resolveStrategies(context: ProviderFetchContext) async -> [any ProviderFetchStrategy] { + let local = AntigravityStatusFetchStrategy() + let oauth = AntigravityOAuthFetchStrategy() + + switch context.sourceMode { + case .cli: + return [local] + case .oauth: + return [oauth] + case .auto: + return [local, oauth] + case .web, .api: + return [] + } + } } struct AntigravityStatusFetchStrategy: ProviderFetchStrategy { @@ -58,6 +74,41 @@ struct AntigravityStatusFetchStrategy: ProviderFetchStrategy { sourceLabel: "local") } + func shouldFallback(on _: Error, context: ProviderFetchContext) -> Bool { + context.sourceMode == .auto + } +} + +struct AntigravityOAuthFetchStrategy: ProviderFetchStrategy { + let id: String = "antigravity.oauth" + let kind: ProviderFetchKind = .oauth + + func isAvailable(_: ProviderFetchContext) async -> Bool { + true + } + + func fetch(_: ProviderFetchContext) async throws -> ProviderFetchResult { + let fetcher = AntigravityRemoteUsageFetcher() + let snapshot = try await fetcher.fetch() + let usage = if snapshot.modelQuotas.isEmpty { + UsageSnapshot( + primary: nil, + secondary: nil, + tertiary: nil, + updatedAt: Date(), + identity: ProviderIdentitySnapshot( + providerID: .antigravity, + accountEmail: snapshot.accountEmail, + accountOrganization: nil, + loginMethod: snapshot.accountPlan)) + } else { + try snapshot.toUsageSnapshot() + } + return self.makeResult( + usage: usage, + sourceLabel: "oauth") + } + func shouldFallback(on _: Error, context _: ProviderFetchContext) -> Bool { false } diff --git a/Sources/CodexBarCore/Providers/Antigravity/AntigravityRemoteUsageFetcher.swift b/Sources/CodexBarCore/Providers/Antigravity/AntigravityRemoteUsageFetcher.swift new file mode 100644 index 000000000..cd1a37718 --- /dev/null +++ b/Sources/CodexBarCore/Providers/Antigravity/AntigravityRemoteUsageFetcher.swift @@ -0,0 +1,660 @@ +import Foundation +#if canImport(FoundationNetworking) +import FoundationNetworking +#endif + +public enum AntigravityRemoteFetchError: LocalizedError, Sendable, Equatable { + case notLoggedIn + case permissionDenied(String) + case apiError(String) + case parseFailed(String) + + public var errorDescription: String? { + switch self { + case .notLoggedIn: + "Antigravity Google auth not found. Use Antigravity login to authenticate." + case let .permissionDenied(message): + "Antigravity remote API permission denied: \(message)" + case let .apiError(message): + "Antigravity remote API error: \(message)" + case let .parseFailed(message): + "Could not parse Antigravity remote usage: \(message)" + } + } +} + +public struct AntigravityRemoteUsageFetcher: Sendable { + public var timeout: TimeInterval = 10.0 + public var homeDirectory: String + public var dataLoader: @Sendable (URLRequest) async throws -> (Data, URLResponse) + + private static let log = CodexBarLog.logger(LogCategories.antigravity) + private static let userAgent = "antigravity" + private static let baseURL = "https://cloudcode-pa.googleapis.com" + private static let loadCodeAssistEndpoint = "\(baseURL)/v1internal:loadCodeAssist" + private static let onboardUserEndpoint = "\(baseURL)/v1internal:onboardUser" + private static let fetchAvailableModelsEndpoint = "\(baseURL)/v1internal:fetchAvailableModels" + private static let retrieveUserQuotaEndpoint = "\(baseURL)/v1internal:retrieveUserQuota" + + public init( + timeout: TimeInterval = 10.0, + homeDirectory: String = NSHomeDirectory(), + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse) = { request in + try await URLSession.shared.data(for: request) + }) + { + self.timeout = timeout + self.homeDirectory = homeDirectory + self.dataLoader = dataLoader + } + + public func fetch() async throws -> AntigravityStatusSnapshot { + let source = try Self.resolveCredentialSource(homeDirectory: self.homeDirectory) + let store = source.primaryStore + guard let credentials = source.credentials else { + throw AntigravityRemoteFetchError.notLoggedIn + } + return try await Self.fetchSnapshot( + using: credentials, + timeout: self.timeout, + store: store, + dataLoader: self.dataLoader) + } + + private static func fetchSnapshot( + using initialCredentials: AntigravityOAuthCredentials, + timeout: TimeInterval, + store: AntigravityOAuthCredentialsStore, + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse)) async throws + -> AntigravityStatusSnapshot + { + guard let storedAccessToken = initialCredentials.accessToken?.trimmedNonEmpty else { + throw AntigravityRemoteFetchError.notLoggedIn + } + + var credentials = initialCredentials + var accessToken = storedAccessToken + if let expiryDate = credentials.expiryDate, expiryDate < Date() { + guard let refreshToken = credentials.refreshToken?.trimmedNonEmpty else { + throw AntigravityRemoteFetchError.notLoggedIn + } + accessToken = try await Self.refreshAccessToken( + credentials: credentials, + refreshToken: refreshToken, + timeout: timeout, + store: store, + dataLoader: dataLoader) + credentials = try store.load() ?? credentials + credentials.accessToken = credentials.accessToken?.trimmedNonEmpty ?? accessToken + } + + let claims = Self.extractClaims(from: credentials) + let codeAssist = try await Self.loadCodeAssist( + accessToken: accessToken, + timeout: timeout, + dataLoader: dataLoader) + let projectId = try await Self.resolveProjectID( + accessToken: accessToken, + storedProjectID: credentials.projectID?.trimmedNonEmpty, + initialResponse: codeAssist, + timeout: timeout, + store: store, + dataLoader: dataLoader) + let models = try await Self.fetchModelQuotas( + accessToken: accessToken, + projectId: projectId, + timeout: timeout, + dataLoader: dataLoader) + + return AntigravityStatusSnapshot( + modelQuotas: models, + accountEmail: claims.email, + accountPlan: Self.resolvePlan(response: codeAssist, claims: claims)) + } + + private static func loadCodeAssist( + accessToken: String, + timeout: TimeInterval, + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse)) async throws + -> CodeAssistResponse + { + let body = [ + "metadata": [ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + ], + ] + return try await Self.sendRequest( + endpoint: Self.loadCodeAssistEndpoint, + accessToken: accessToken, + body: body, + timeout: timeout, + dataLoader: dataLoader) + } + + private static func fetchAvailableModels( + accessToken: String, + projectId: String?, + timeout: TimeInterval, + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse)) async throws + -> FetchAvailableModelsResponse + { + let body: [String: Any] = if let projectId = projectId?.trimmedNonEmpty { + ["project": projectId] + } else { + [:] + } + return try await Self.sendRequest( + endpoint: Self.fetchAvailableModelsEndpoint, + accessToken: accessToken, + body: body, + timeout: timeout, + dataLoader: dataLoader) + } + + private static func fetchModelQuotas( + accessToken: String, + projectId: String?, + timeout: TimeInterval, + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse)) async throws + -> [AntigravityModelQuota] + { + do { + let response = try await Self.fetchAvailableModels( + accessToken: accessToken, + projectId: projectId, + timeout: timeout, + dataLoader: dataLoader) + return try Self.parseModelQuotas(response) + } catch let error as AntigravityRemoteFetchError { + guard case .permissionDenied = error else { + throw error + } + Self.log.info("Falling back to retrieveUserQuota for Antigravity remote usage") + do { + let response = try await Self.retrieveUserQuota( + accessToken: accessToken, + projectId: projectId, + timeout: timeout, + dataLoader: dataLoader) + return try Self.parseQuotaBuckets(response) + } catch let quotaError as AntigravityRemoteFetchError { + guard case .permissionDenied = quotaError else { + throw quotaError + } + Self.log.info("Antigravity remote quota endpoints are not permitted for this account") + return [] + } + } + } + + private static func retrieveUserQuota( + accessToken: String, + projectId: String?, + timeout: TimeInterval, + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse)) async throws + -> RetrieveUserQuotaResponse + { + let body: [String: Any] = if let projectId = projectId?.trimmedNonEmpty { + ["project": projectId] + } else { + [:] + } + return try await Self.sendRequest( + endpoint: Self.retrieveUserQuotaEndpoint, + accessToken: accessToken, + body: body, + timeout: timeout, + dataLoader: dataLoader) + } + + private static func resolveProjectID( + accessToken: String, + storedProjectID: String?, + initialResponse: CodeAssistResponse, + timeout: TimeInterval, + store: AntigravityOAuthCredentialsStore, + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse)) async throws + -> String? + { + if let storedProjectID { + return storedProjectID + } + + if let projectID = initialResponse.projectID { + try? Self.updateStoredProjectID(projectID, store: store) + return projectID + } + + guard let tierID = Self.pickOnboardTier(from: initialResponse) else { + return nil + } + + let onboardBody: [String: Any] = [ + "tierId": tierID, + "metadata": [ + "ideType": "ANTIGRAVITY", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + ], + ] + + do { + let onboardResponse: OnboardResponse = try await Self.sendRequest( + endpoint: Self.onboardUserEndpoint, + accessToken: accessToken, + body: onboardBody, + timeout: timeout, + dataLoader: dataLoader) + if let projectID = onboardResponse.projectID { + try? Self.updateStoredProjectID(projectID, store: store) + return projectID + } + } catch { + Self.log.warning("Antigravity onboarding request failed", metadata: [ + "error": "\(error.localizedDescription)", + ]) + } + + for _ in 0 ..< 5 { + try? await Task.sleep(for: .milliseconds(2000)) + let refreshed = try await Self.loadCodeAssist( + accessToken: accessToken, + timeout: timeout, + dataLoader: dataLoader) + if let projectID = refreshed.projectID { + try? Self.updateStoredProjectID(projectID, store: store) + return projectID + } + } + + return nil + } + + private static func sendRequest( + endpoint: String, + accessToken: String, + body: [String: Any], + timeout: TimeInterval, + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse)) async throws + -> Response + { + guard let url = URL(string: endpoint) else { + throw AntigravityRemoteFetchError.apiError("Invalid endpoint URL") + } + + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.timeoutInterval = timeout + request.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization") + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue(Self.userAgent, forHTTPHeaderField: "User-Agent") + request.httpBody = try JSONSerialization.data(withJSONObject: body) + + let (data, response) = try await dataLoader(request) + guard let httpResponse = response as? HTTPURLResponse else { + throw AntigravityRemoteFetchError.apiError("Invalid response") + } + + switch httpResponse.statusCode { + case 200: + break + case 401: + throw AntigravityRemoteFetchError.notLoggedIn + case 403: + let message = String(data: data, encoding: .utf8)?.trimmedNonEmpty ?? "HTTP 403" + throw AntigravityRemoteFetchError.permissionDenied(message) + default: + let message = String(data: data, encoding: .utf8)?.trimmedNonEmpty ?? "HTTP \(httpResponse.statusCode)" + throw AntigravityRemoteFetchError.apiError("HTTP \(httpResponse.statusCode): \(message)") + } + + do { + return try JSONDecoder().decode(Response.self, from: data) + } catch { + throw AntigravityRemoteFetchError.parseFailed(error.localizedDescription) + } + } + + private static func parseModelQuotas(_ response: FetchAvailableModelsResponse) throws -> [AntigravityModelQuota] { + let models = response.models ?? [:] + return models.compactMap { modelID, model in + guard let quotaInfo = model.quotaInfo else { return nil } + let resetTime = quotaInfo.resetTime.flatMap(Self.parseResetTime(_:)) + let label = model.displayName?.trimmedNonEmpty + ?? model.label?.trimmedNonEmpty + ?? modelID + return AntigravityModelQuota( + label: label, + modelId: modelID, + remainingFraction: quotaInfo.remainingFraction, + resetTime: resetTime, + resetDescription: resetTime.map { UsageFormatter.resetDescription(from: $0) }) + } + } + + private static func parseQuotaBuckets(_ response: RetrieveUserQuotaResponse) throws -> [AntigravityModelQuota] { + guard let buckets = response.buckets, !buckets.isEmpty else { + throw AntigravityRemoteFetchError.parseFailed("No quota buckets in response") + } + + var modelQuotaMap: [String: (fraction: Double?, resetTime: String?)] = [:] + for bucket in buckets { + guard let modelID = bucket.modelId?.trimmedNonEmpty else { continue } + let next = (bucket.remainingFraction, bucket.resetTime) + if let existing = modelQuotaMap[modelID] { + let existingValue = existing.fraction ?? Double.greatestFiniteMagnitude + let nextValue = next.0 ?? Double.greatestFiniteMagnitude + if nextValue < existingValue { + modelQuotaMap[modelID] = next + } + } else { + modelQuotaMap[modelID] = next + } + } + + return modelQuotaMap.keys.sorted().compactMap { modelID in + guard let info = modelQuotaMap[modelID] else { return nil } + let resetTime = info.resetTime.flatMap(Self.parseResetTime(_:)) + return AntigravityModelQuota( + label: modelID, + modelId: modelID, + remainingFraction: info.fraction, + resetTime: resetTime, + resetDescription: resetTime.map { UsageFormatter.resetDescription(from: $0) }) + } + } + + private static func resolvePlan(response: CodeAssistResponse, claims: TokenClaims) -> String? { + if let planType = response.planInfo?.planType?.trimmedNonEmpty { + return planType + } + + switch (response.currentTier?.id?.trimmedNonEmpty, claims.hostedDomain) { + case ("standard-tier", _): + return "Paid" + case ("free-tier", .some): + return "Workspace" + case ("free-tier", .none): + return "Free" + case ("legacy-tier", _): + return "Legacy" + default: + return response.currentTier?.name?.trimmedNonEmpty + } + } + + private static func pickOnboardTier(from response: CodeAssistResponse) -> String? { + if let defaultTier = response.allowedTiers? + .first(where: { $0.isDefault == true && $0.id?.trimmedNonEmpty != nil })?.id?.trimmedNonEmpty + { + return defaultTier + } + if let firstTier = response.allowedTiers? + .first(where: { $0.id?.trimmedNonEmpty != nil })?.id?.trimmedNonEmpty + { + return firstTier + } + if let paidTier = response.paidTier?.id?.trimmedNonEmpty { + return paidTier + } + if let currentTier = response.currentTier?.id?.trimmedNonEmpty { + return currentTier + } + return nil + } + + private static func parseResetTime(_ value: String) -> Date? { + let formatter = ISO8601DateFormatter() + formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds] + if let date = formatter.date(from: value) { + return date + } + formatter.formatOptions = [.withInternetDateTime] + return formatter.date(from: value) + } + + private static func credentialsStore(homeDirectory: String) -> AntigravityOAuthCredentialsStore { + let homeURL = URL(fileURLWithPath: homeDirectory, isDirectory: true) + return AntigravityOAuthCredentialsStore(fileURL: AntigravityOAuthCredentialsStore.defaultURL(home: homeURL)) + } + + private static func resolveCredentialSource(homeDirectory: String) throws -> ( + credentials: AntigravityOAuthCredentials?, + primaryStore: AntigravityOAuthCredentialsStore) + { + let primaryStore = Self.credentialsStore(homeDirectory: homeDirectory) + return (try primaryStore.load(), primaryStore) + } + + private static func refreshAccessToken( + credentials: AntigravityOAuthCredentials, + refreshToken: String, + timeout: TimeInterval, + store: AntigravityOAuthCredentialsStore, + dataLoader: @escaping @Sendable (URLRequest) async throws -> (Data, URLResponse)) async throws + -> String + { + let oauthClient = try Self.refreshOAuthClient(from: credentials) + + var request = URLRequest(url: AntigravityOAuthConfig.tokenURL) + request.httpMethod = "POST" + request.timeoutInterval = timeout + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + request.httpBody = Self.formBody([ + "client_id": oauthClient.clientID, + "client_secret": oauthClient.clientSecret, + "refresh_token": refreshToken, + "grant_type": "refresh_token", + ]) + + let (data, response) = try await dataLoader(request) + guard let httpResponse = response as? HTTPURLResponse else { + throw AntigravityRemoteFetchError.apiError("Invalid refresh response") + } + guard httpResponse.statusCode == 200 else { + throw AntigravityRemoteFetchError.notLoggedIn + } + guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any], + let accessToken = json["access_token"] as? String + else { + throw AntigravityRemoteFetchError.parseFailed("Could not parse refresh response") + } + + try Self.updateStoredCredentials(json, store: store) + return accessToken + } + + private static func refreshOAuthClient(from credentials: AntigravityOAuthCredentials) throws -> AntigravityOAuthClient { + if let clientID = credentials.clientID?.trimmedNonEmpty, + let clientSecret = credentials.clientSecret?.trimmedNonEmpty + { + return AntigravityOAuthClient(clientID: clientID, clientSecret: clientSecret) + } + + guard let client = AntigravityOAuthConfig.resolvedClient() else { + throw AntigravityRemoteFetchError.apiError(AntigravityOAuthConfig.missingCredentialsMessage) + } + return client + } + + private static func updateStoredCredentials( + _ refreshResponse: [String: Any], + store: AntigravityOAuthCredentialsStore) throws + { + guard var credentials = try store.load() else { return } + if let accessToken = refreshResponse["access_token"] as? String { + credentials.accessToken = accessToken + } + if let expiresIn = refreshResponse["expires_in"] as? Double { + credentials.expiryDateMilliseconds = (Date().timeIntervalSince1970 + expiresIn) * 1000 + } + if let expiresIn = refreshResponse["expires_in"] as? Int { + credentials.expiryDateMilliseconds = (Date().timeIntervalSince1970 + Double(expiresIn)) * 1000 + } + if let idToken = refreshResponse["id_token"] as? String { + credentials.idToken = idToken + } + try store.save(credentials) + } + + private static func updateStoredProjectID(_ projectID: String, store: AntigravityOAuthCredentialsStore) throws { + guard var credentials = try store.load() else { return } + guard credentials.projectID?.trimmedNonEmpty != projectID else { return } + credentials.projectID = projectID + try store.save(credentials) + } + + private static func formBody(_ values: [String: String]) -> Data? { + var components = URLComponents() + components.queryItems = values.map { key, value in + URLQueryItem(name: key, value: value) + } + return components.query?.data(using: .utf8) + } + + private struct TokenClaims { + let email: String? + let hostedDomain: String? + } + + private static func extractClaims(from credentials: AntigravityOAuthCredentials) -> TokenClaims { + let tokenClaims = Self.extractClaimsFromToken(credentials.idToken) + return TokenClaims( + email: tokenClaims.email ?? credentials.email?.trimmedNonEmpty, + hostedDomain: tokenClaims.hostedDomain) + } + + private static func extractClaimsFromToken(_ idToken: String?) -> TokenClaims { + guard let idToken else { + return TokenClaims(email: nil, hostedDomain: nil) + } + + let parts = idToken.components(separatedBy: ".") + guard parts.count >= 2 else { + return TokenClaims(email: nil, hostedDomain: nil) + } + + var payload = parts[1] + .replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let remainder = payload.count % 4 + if remainder > 0 { + payload += String(repeating: "=", count: 4 - remainder) + } + + guard let data = Data(base64Encoded: payload, options: .ignoreUnknownCharacters), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { + return TokenClaims(email: nil, hostedDomain: nil) + } + + return TokenClaims( + email: (json["email"] as? String)?.trimmedNonEmpty, + hostedDomain: (json["hd"] as? String)?.trimmedNonEmpty) + } +} + +private extension String { + var trimmedNonEmpty: String? { + let trimmed = self.trimmingCharacters(in: .whitespacesAndNewlines) + return trimmed.isEmpty ? nil : trimmed + } +} + +private struct ProjectReference: Decodable { + let value: String? + + init(from decoder: Decoder) throws { + let single = try decoder.singleValueContainer() + if let stringValue = try? single.decode(String.self) { + self.value = stringValue + return + } + + let keyed = try decoder.container(keyedBy: CodingKeys.self) + self.value = try keyed.decodeIfPresent(String.self, forKey: .id) + ?? keyed.decodeIfPresent(String.self, forKey: .projectID) + } + + private enum CodingKeys: String, CodingKey { + case id + case projectID = "projectId" + } +} + +private struct CodeAssistResponse: Decodable { + let planInfo: CodeAssistPlanInfo? + let currentTier: TierInfo? + let paidTier: TierInfo? + let allowedTiers: [AllowedTier]? + let cloudaicompanionProject: ProjectReference? + + var projectID: String? { + self.cloudaicompanionProject?.value?.trimmingCharacters(in: .whitespacesAndNewlines).nilIfEmpty + } +} + +private struct CodeAssistPlanInfo: Decodable { + let planType: String? +} + +private struct TierInfo: Decodable { + let id: String? + let name: String? +} + +private struct AllowedTier: Decodable { + let id: String? + let isDefault: Bool? +} + +private struct OnboardResponse: Decodable { + let response: OnboardInnerResponse? + + var projectID: String? { + self.response?.cloudaicompanionProject?.value?.trimmingCharacters(in: .whitespacesAndNewlines).nilIfEmpty + } +} + +private struct OnboardInnerResponse: Decodable { + let cloudaicompanionProject: ProjectReference? +} + +private struct FetchAvailableModelsResponse: Decodable { + let models: [String: AntigravityRemoteModel]? +} + +private struct RetrieveUserQuotaResponse: Decodable { + let buckets: [RetrieveUserQuotaBucket]? +} + +private struct RetrieveUserQuotaBucket: Decodable { + let modelId: String? + let remainingFraction: Double? + let resetTime: String? +} + +private struct AntigravityRemoteModel: Decodable { + let displayName: String? + let label: String? + let quotaInfo: AntigravityRemoteQuotaInfo? +} + +private struct AntigravityRemoteQuotaInfo: Decodable { + let remainingFraction: Double? + let resetTime: String? +} + +private extension Optional where Wrapped == String { + var trimmedNonEmpty: String? { + self?.trimmingCharacters(in: .whitespacesAndNewlines).nilIfEmpty + } +} + +private extension String { + var nilIfEmpty: String? { + self.isEmpty ? nil : self + } +} diff --git a/Sources/CodexBarCore/Providers/Antigravity/AntigravityStatusProbe.swift b/Sources/CodexBarCore/Providers/Antigravity/AntigravityStatusProbe.swift index 6ad656a3d..5192a6e5f 100644 --- a/Sources/CodexBarCore/Providers/Antigravity/AntigravityStatusProbe.swift +++ b/Sources/CodexBarCore/Providers/Antigravity/AntigravityStatusProbe.swift @@ -145,17 +145,29 @@ public struct AntigravityStatusSnapshot: Sendable { let candidates = models.filter { $0.family == family && $0.selectionPriority != nil } guard !candidates.isEmpty else { return nil } return candidates.min { lhs, rhs in + let lhsHasRemainingFraction = lhs.quota.remainingFraction != nil + let rhsHasRemainingFraction = rhs.quota.remainingFraction != nil + if lhsHasRemainingFraction != rhsHasRemainingFraction { + return lhsHasRemainingFraction && !rhsHasRemainingFraction + } let lhsPriority = lhs.selectionPriority ?? Int.max let rhsPriority = rhs.selectionPriority ?? Int.max if lhsPriority != rhsPriority { return lhsPriority < rhsPriority } - let lhsHasRemainingFraction = lhs.quota.remainingFraction != nil - let rhsHasRemainingFraction = rhs.quota.remainingFraction != nil - if lhsHasRemainingFraction != rhsHasRemainingFraction { - return lhsHasRemainingFraction && !rhsHasRemainingFraction + if lhs.quota.remainingPercent != rhs.quota.remainingPercent { + return lhs.quota.remainingPercent < rhs.quota.remainingPercent + } + switch (lhs.quota.resetTime, rhs.quota.resetTime) { + case let (.some(left), .some(right)) where left != right: + return left < right + case (.some, .none): + return true + case (.none, .some): + return false + default: + return lhs.quota.label.localizedCaseInsensitiveCompare(rhs.quota.label) == .orderedAscending } - return lhs.quota.remainingPercent < rhs.quota.remainingPercent }?.quota } diff --git a/Sources/CodexBarCore/Providers/Antigravity/AntigravityUsageDataSource.swift b/Sources/CodexBarCore/Providers/Antigravity/AntigravityUsageDataSource.swift new file mode 100644 index 000000000..7db69bf62 --- /dev/null +++ b/Sources/CodexBarCore/Providers/Antigravity/AntigravityUsageDataSource.swift @@ -0,0 +1,30 @@ +import Foundation + +public enum AntigravityUsageDataSource: String, CaseIterable, Identifiable, Sendable { + case auto + case oauth + case cli + + public var id: String { + self.rawValue + } + + public var displayName: String { + switch self { + case .auto: "Auto" + case .oauth: "Google OAuth" + case .cli: "Local IDE API" + } + } + + public var sourceLabel: String { + switch self { + case .auto: + "auto" + case .oauth: + "oauth" + case .cli: + "cli" + } + } +} diff --git a/Tests/CodexBarTests/AntigravityLoginAlertTests.swift b/Tests/CodexBarTests/AntigravityLoginAlertTests.swift new file mode 100644 index 000000000..102d635be --- /dev/null +++ b/Tests/CodexBarTests/AntigravityLoginAlertTests.swift @@ -0,0 +1,34 @@ +import Testing +@testable import CodexBar + +struct AntigravityLoginAlertTests { + @Test + func `returns alert for timeout`() { + let result = AntigravityLoginRunner.Result(outcome: .timedOut) + let info = StatusItemController.antigravityLoginAlertInfo(for: result) + #expect(info?.title == "Antigravity login timed out") + } + + @Test + func `returns alert for launch failure`() { + let result = AntigravityLoginRunner.Result(outcome: .launchFailed("https://example.com/login")) + let info = StatusItemController.antigravityLoginAlertInfo(for: result) + #expect(info?.title == "Could not open browser for Antigravity") + #expect(info?.message.contains("https://example.com/login") == true) + } + + @Test + func `returns alert for auth failure`() { + let result = AntigravityLoginRunner.Result(outcome: .failed("permission denied")) + let info = StatusItemController.antigravityLoginAlertInfo(for: result) + #expect(info?.title == "Antigravity login failed") + #expect(info?.message == "permission denied") + } + + @Test + func `returns nil on success`() { + let result = AntigravityLoginRunner.Result(outcome: .success("user@example.com")) + let info = StatusItemController.antigravityLoginAlertInfo(for: result) + #expect(info == nil) + } +} diff --git a/Tests/CodexBarTests/AntigravityRemoteUsageFetcherTests.swift b/Tests/CodexBarTests/AntigravityRemoteUsageFetcherTests.swift new file mode 100644 index 000000000..642e3e62c --- /dev/null +++ b/Tests/CodexBarTests/AntigravityRemoteUsageFetcherTests.swift @@ -0,0 +1,513 @@ +import CodexBarCore +import Foundation +import Testing + +@Suite(.serialized) +struct AntigravityRemoteUsageFetcherTests { + @Test + func `remote fetch maps cloud code models into antigravity usage`() async throws { + let env = try GeminiTestEnvironment() + defer { env.cleanup() } + try env.writeAntigravityCredentials( + accessToken: "token", + refreshToken: nil, + expiry: Date().addingTimeInterval(3600), + idToken: GeminiAPITestHelpers.makeIDToken(email: "user@company.com", hostedDomain: "company.com"), + email: "user@company.com") + + let dataLoader = GeminiAPITestHelpers.dataLoader { request in + guard let url = request.url, let host = url.host else { + throw URLError(.badURL) + } + + switch host { + case "cloudcode-pa.googleapis.com": + if url.path == "/v1internal:loadCodeAssist" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.jsonData([ + "currentTier": ["id": "standard-tier", "name": "standard"], + "cloudaicompanionProject": "managed-project-123", + ])) + } + if url.path == "/v1internal:fetchAvailableModels" { + let body = try #require(request.httpBody) + let json = try #require(JSONSerialization.jsonObject(with: body) as? [String: Any]) + #expect(json["project"] as? String == "managed-project-123") + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: Self.availableModelsResponse()) + } + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + default: + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + } + } + + let fetcher = AntigravityRemoteUsageFetcher( + timeout: 1, + homeDirectory: env.homeURL.path, + dataLoader: dataLoader) + let snapshot = try await fetcher.fetch() + + #expect(snapshot.accountEmail == "user@company.com") + #expect(snapshot.accountPlan == "Paid") + + let usage = try snapshot.toUsageSnapshot() + #expect(usage.primary?.remainingPercent.rounded() == 50) + #expect(usage.secondary?.remainingPercent.rounded() == 80) + #expect(usage.tertiary?.remainingPercent.rounded() == 20) + } + + @Test + func `remote fetch refreshes expired shared google token`() async throws { + let env = try GeminiTestEnvironment() + defer { env.cleanup() } + try env.writeAntigravityCredentials( + accessToken: "old-token", + refreshToken: "refresh-token", + expiry: Date().addingTimeInterval(-3600), + idToken: GeminiAPITestHelpers.makeIDToken(email: "stale@example.com"), + email: "stale@example.com", + clientID: "test-client-id", + clientSecret: "test-client-secret") + + let dataLoader = GeminiAPITestHelpers.dataLoader { request in + guard let url = request.url, let host = url.host else { + throw URLError(.badURL) + } + + switch host { + case "oauth2.googleapis.com": + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.jsonData([ + "access_token": "new-token", + "expires_in": 3600, + "id_token": GeminiAPITestHelpers.makeIDToken(email: "refreshed@example.com"), + ])) + case "cloudcode-pa.googleapis.com": + let auth = request.value(forHTTPHeaderField: "Authorization") + #expect(auth == "Bearer new-token") + if url.path == "/v1internal:loadCodeAssist" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.jsonData([ + "currentTier": ["id": "standard-tier", "name": "standard"], + "cloudaicompanionProject": "managed-project-123", + ])) + } + if url.path == "/v1internal:fetchAvailableModels" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: Self.availableModelsResponse()) + } + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + default: + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + } + } + + let fetcher = AntigravityRemoteUsageFetcher( + timeout: 2, + homeDirectory: env.homeURL.path, + dataLoader: dataLoader) + let snapshot = try await fetcher.fetch() + + let updated = try env.readAntigravityCredentials() + #expect(updated["access_token"] as? String == "new-token") + #expect(snapshot.accountEmail == "refreshed@example.com") + } + + @Test + func `remote refresh requires configured oauth client`() async throws { + let env = try GeminiTestEnvironment() + defer { env.cleanup() } + try env.writeAntigravityCredentials( + accessToken: "old-token", + refreshToken: "refresh-token", + expiry: Date().addingTimeInterval(-3600), + idToken: GeminiAPITestHelpers.makeIDToken(email: "user@example.com"), + email: "user@example.com") + + let fetcher = AntigravityRemoteUsageFetcher( + timeout: 1, + homeDirectory: env.homeURL.path, + dataLoader: GeminiAPITestHelpers.dataLoader { _ in + throw URLError(.badServerResponse) + }) + + do { + _ = try await fetcher.fetch() + #expect(Bool(false), "Expected missing OAuth client configuration error") + } catch let error as AntigravityRemoteFetchError { + guard case let .apiError(message) = error else { + #expect(Bool(false), "Unexpected Antigravity error: \(error)") + return + } + #expect(message.contains("ANTIGRAVITY_OAUTH_CLIENT_ID")) + } catch { + #expect(Bool(false), "Unexpected error: \(error)") + } + } + + @Test + func `remote fetch onboards project before fetching models`() async throws { + let env = try GeminiTestEnvironment() + defer { env.cleanup() } + try env.writeAntigravityCredentials( + accessToken: "token", + refreshToken: nil, + expiry: Date().addingTimeInterval(3600), + idToken: GeminiAPITestHelpers.makeIDToken(email: "user@example.com"), + email: "user@example.com") + + final class Recorder: @unchecked Sendable { + private let lock = NSLock() + private var projects: [String] = [] + + func append(_ value: String) { + self.lock.lock() + self.projects.append(value) + self.lock.unlock() + } + + func last() -> String? { + self.lock.lock() + defer { self.lock.unlock() } + return self.projects.last + } + } + + let recorder = Recorder() + let dataLoader = GeminiAPITestHelpers.dataLoader { request in + guard let url = request.url, let host = url.host else { + throw URLError(.badURL) + } + + switch host { + case "cloudcode-pa.googleapis.com": + if url.path == "/v1internal:loadCodeAssist" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.jsonData([ + "currentTier": ["id": "standard-tier", "name": "standard"], + "allowedTiers": [["id": "standard-tier", "isDefault": true]], + ])) + } + if url.path == "/v1internal:onboardUser" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.jsonData([ + "response": [ + "cloudaicompanionProject": [ + "id": "onboarded-project-456", + ], + ], + ])) + } + if url.path == "/v1internal:fetchAvailableModels" { + let body = try #require(request.httpBody) + let json = try #require(JSONSerialization.jsonObject(with: body) as? [String: Any]) + if let project = json["project"] as? String { + recorder.append(project) + } + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: Self.availableModelsResponse()) + } + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + default: + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + } + } + + let fetcher = AntigravityRemoteUsageFetcher( + timeout: 1, + homeDirectory: env.homeURL.path, + dataLoader: dataLoader) + _ = try await fetcher.fetch() + + #expect(recorder.last() == "onboarded-project-456") + } + + @Test + func `remote fetch falls back to retrieve user quota when model endpoint is forbidden`() async throws { + let env = try GeminiTestEnvironment() + defer { env.cleanup() } + try env.writeAntigravityCredentials( + accessToken: "token", + refreshToken: nil, + expiry: Date().addingTimeInterval(3600), + idToken: GeminiAPITestHelpers.makeIDToken(email: "user@example.com"), + email: "user@example.com") + + final class Counter: @unchecked Sendable { + private let lock = NSLock() + private var value = 0 + + func increment() { + self.lock.lock() + self.value += 1 + self.lock.unlock() + } + + func get() -> Int { + self.lock.lock() + defer { self.lock.unlock() } + return self.value + } + } + + let quotaCalls = Counter() + let dataLoader = GeminiAPITestHelpers.dataLoader { request in + guard let url = request.url, let host = url.host else { + throw URLError(.badURL) + } + + switch host { + case "cloudcode-pa.googleapis.com": + if url.path == "/v1internal:loadCodeAssist" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.jsonData([ + "currentTier": ["id": "standard-tier", "name": "standard"], + "cloudaicompanionProject": "managed-project-123", + ])) + } + if url.path == "/v1internal:fetchAvailableModels" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 403, + body: GeminiAPITestHelpers.jsonData([ + "error": [ + "code": 403, + "message": "The caller does not have permission", + "status": "PERMISSION_DENIED", + ], + ])) + } + if url.path == "/v1internal:retrieveUserQuota" { + quotaCalls.increment() + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.sampleQuotaResponse()) + } + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + default: + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + } + } + + let fetcher = AntigravityRemoteUsageFetcher( + timeout: 1, + homeDirectory: env.homeURL.path, + dataLoader: dataLoader) + let snapshot = try await fetcher.fetch() + let usage = try snapshot.toUsageSnapshot() + + #expect(quotaCalls.get() == 1) + #expect(usage.secondary?.remainingPercent == 60.0) + #expect(usage.tertiary?.remainingPercent == 90.0) + } + + @Test + func `antigravity descriptor advertises oauth mode`() { + let descriptor = ProviderDescriptorRegistry.descriptor(for: .antigravity) + #expect(descriptor.fetchPlan.sourceModes == [.auto, .cli, .oauth]) + } + + @Test + func `remote fetch returns identity when both remote quota endpoints are forbidden`() async throws { + let env = try GeminiTestEnvironment() + defer { env.cleanup() } + try env.writeAntigravityCredentials( + accessToken: "token", + refreshToken: nil, + expiry: Date().addingTimeInterval(3600), + idToken: GeminiAPITestHelpers.makeIDToken(email: "user@example.com"), + email: "user@example.com") + + let dataLoader = GeminiAPITestHelpers.dataLoader { request in + guard let url = request.url, let host = url.host else { + throw URLError(.badURL) + } + + switch host { + case "cloudcode-pa.googleapis.com": + if url.path == "/v1internal:loadCodeAssist" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.jsonData([ + "currentTier": ["id": "standard-tier", "name": "standard"], + "cloudaicompanionProject": "managed-project-123", + ])) + } + if url.path == "/v1internal:fetchAvailableModels" || url.path == "/v1internal:retrieveUserQuota" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 403, + body: GeminiAPITestHelpers.jsonData([ + "error": [ + "code": 403, + "message": "The caller does not have permission", + "status": "PERMISSION_DENIED", + ], + ])) + } + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + default: + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + } + } + + let snapshot = try await AntigravityRemoteUsageFetcher( + timeout: 1, + homeDirectory: env.homeURL.path, + dataLoader: dataLoader) + .fetch() + + #expect(snapshot.modelQuotas.isEmpty) + #expect(snapshot.accountEmail == "user@example.com") + #expect(snapshot.accountPlan == "Paid") + } + + @Test + func `remote fetch ignores gemini credentials when antigravity auth is missing`() async throws { + let env = try GeminiTestEnvironment() + defer { env.cleanup() } + try env.writeCredentials( + accessToken: "gemini-token", + refreshToken: nil, + expiry: Date().addingTimeInterval(3600), + idToken: GeminiAPITestHelpers.makeIDToken(email: "gemini@example.com")) + + let fetcher = AntigravityRemoteUsageFetcher( + timeout: 1, + homeDirectory: env.homeURL.path, + dataLoader: GeminiAPITestHelpers.dataLoader { _ in + throw URLError(.badServerResponse) + }) + + await #expect(throws: AntigravityRemoteFetchError.notLoggedIn) { + try await fetcher.fetch() + } + } + + @Test + func `remote fetch prefers stored project id from antigravity credentials`() async throws { + let env = try GeminiTestEnvironment() + defer { env.cleanup() } + try env.writeAntigravityCredentials( + accessToken: "token", + refreshToken: nil, + expiry: Date().addingTimeInterval(3600), + idToken: GeminiAPITestHelpers.makeIDToken(email: "user@example.com"), + email: "user@example.com", + projectID: "stored-project-789") + + final class Recorder: @unchecked Sendable { + private let lock = NSLock() + private var projects: [String] = [] + + func append(_ value: String) { + self.lock.lock() + self.projects.append(value) + self.lock.unlock() + } + + func last() -> String? { + self.lock.lock() + defer { self.lock.unlock() } + return self.projects.last + } + } + + let recorder = Recorder() + let dataLoader = GeminiAPITestHelpers.dataLoader { request in + guard let url = request.url, let host = url.host else { + throw URLError(.badURL) + } + + switch host { + case "cloudcode-pa.googleapis.com": + if url.path == "/v1internal:loadCodeAssist" { + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: GeminiAPITestHelpers.jsonData([ + "currentTier": ["id": "standard-tier", "name": "standard"], + ])) + } + if url.path == "/v1internal:fetchAvailableModels" { + let body = try #require(request.httpBody) + let json = try #require(JSONSerialization.jsonObject(with: body) as? [String: Any]) + if let project = json["project"] as? String { + recorder.append(project) + } + return GeminiAPITestHelpers.response( + url: url.absoluteString, + status: 200, + body: Self.availableModelsResponse()) + } + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + default: + return GeminiAPITestHelpers.response(url: url.absoluteString, status: 404, body: Data()) + } + } + + _ = try await AntigravityRemoteUsageFetcher( + timeout: 1, + homeDirectory: env.homeURL.path, + dataLoader: dataLoader) + .fetch() + + #expect(recorder.last() == "stored-project-789") + } + + private static func availableModelsResponse() -> Data { + GeminiAPITestHelpers.jsonData([ + "models": [ + "claude-sonnet-4": [ + "displayName": "Claude Sonnet 4", + "quotaInfo": [ + "remainingFraction": 0.5, + "resetTime": "2025-01-01T00:00:00Z", + ], + ], + "gemini-3-pro-low": [ + "displayName": "Gemini 3 Pro Low", + "quotaInfo": [ + "remainingFraction": 0.8, + "resetTime": "2025-01-01T00:00:00Z", + ], + ], + "gemini-3-flash": [ + "displayName": "Gemini 3 Flash", + "quotaInfo": [ + "remainingFraction": 0.2, + "resetTime": "2025-01-01T00:00:00Z", + ], + ], + "gemini-3-flash-lite": [ + "displayName": "Gemini 3 Flash Lite", + "quotaInfo": [ + "remainingFraction": 0.7, + "resetTime": "2025-01-01T00:00:00Z", + ], + ], + ], + ]) + } +} diff --git a/Tests/CodexBarTests/AntigravityStatusProbeTests.swift b/Tests/CodexBarTests/AntigravityStatusProbeTests.swift index 8946f630d..0cffdd59f 100644 --- a/Tests/CodexBarTests/AntigravityStatusProbeTests.swift +++ b/Tests/CodexBarTests/AntigravityStatusProbeTests.swift @@ -175,6 +175,30 @@ struct AntigravityStatusProbeTests { #expect(usage.secondary?.remainingPercent.rounded() == 90) } + @Test + func `gemini pro prefers model with remaining data over low priority placeholder`() throws { + let snapshot = AntigravityStatusSnapshot( + modelQuotas: [ + AntigravityModelQuota( + label: "Gemini 3 Pro (Low)", + modelId: "MODEL_PLACEHOLDER_M36", + remainingFraction: nil, + resetTime: Date(timeIntervalSince1970: 1_735_000_000), + resetDescription: nil), + AntigravityModelQuota( + label: "Gemini 3 Pro (High)", + modelId: "MODEL_PLACEHOLDER_M37", + remainingFraction: 1, + resetTime: Date(timeIntervalSince1970: 1_735_100_000), + resetDescription: nil), + ], + accountEmail: nil, + accountPlan: nil) + + let usage = try snapshot.toUsageSnapshot() + #expect(usage.secondary?.remainingPercent.rounded() == 100) + } + @Test func `gemini flash does not fallback to lite variant`() throws { let snapshot = AntigravityStatusSnapshot( @@ -232,6 +256,58 @@ struct AntigravityStatusProbeTests { #expect(usage.tertiary?.remainingPercent.rounded() == 100) } + @Test + func `matches remote antigravity model names with parentheses`() throws { + let resetTime = Date(timeIntervalSince1970: 1_775_000_000) + let snapshot = AntigravityStatusSnapshot( + modelQuotas: [ + AntigravityModelQuota( + label: "Claude Opus 4.6 (Thinking)", + modelId: "MODEL_PLACEHOLDER_M50", + remainingFraction: 1, + resetTime: resetTime, + resetDescription: nil), + AntigravityModelQuota( + label: "Claude Sonnet 4.6 (Thinking)", + modelId: "MODEL_PLACEHOLDER_M51", + remainingFraction: 1, + resetTime: resetTime, + resetDescription: nil), + AntigravityModelQuota( + label: "Gemini 3 Pro (High)", + modelId: "MODEL_PLACEHOLDER_M52", + remainingFraction: 1, + resetTime: resetTime, + resetDescription: nil), + AntigravityModelQuota( + label: "Gemini 3 Pro (Low)", + modelId: "MODEL_PLACEHOLDER_M53", + remainingFraction: 1, + resetTime: resetTime, + resetDescription: nil), + AntigravityModelQuota( + label: "Gemini 3 Flash", + modelId: "MODEL_PLACEHOLDER_M54", + remainingFraction: 1, + resetTime: resetTime, + resetDescription: nil), + AntigravityModelQuota( + label: "GPT-OSS 120B (Medium)", + modelId: "MODEL_PLACEHOLDER_M55", + remainingFraction: 1, + resetTime: resetTime, + resetDescription: nil), + ], + accountEmail: "user@example.com", + accountPlan: "Pro") + + let usage = try snapshot.toUsageSnapshot() + #expect(usage.primary?.remainingPercent.rounded() == 100) + #expect(usage.secondary?.remainingPercent.rounded() == 100) + #expect(usage.tertiary?.remainingPercent.rounded() == 100) + #expect(usage.identity?.accountEmail == "user@example.com") + } + @Test func `model without remaining fraction keeps reset time`() throws { let resetTime = Date(timeIntervalSince1970: 1_735_000_000) diff --git a/Tests/CodexBarTests/AppDelegateTests.swift b/Tests/CodexBarTests/AppDelegateTests.swift index c8b784c9b..aa33aa9e7 100644 --- a/Tests/CodexBarTests/AppDelegateTests.swift +++ b/Tests/CodexBarTests/AppDelegateTests.swift @@ -49,4 +49,5 @@ struct AppDelegateTests { @MainActor private final class DummyStatusController: StatusItemControlling { func openMenuFromShortcut() {} + func runLoginFlowFromSettings(provider _: UsageProvider) async {} } diff --git a/Tests/CodexBarTests/GeminiTestEnvironment.swift b/Tests/CodexBarTests/GeminiTestEnvironment.swift index 3d6b0b4bb..45eedfd2f 100644 --- a/Tests/CodexBarTests/GeminiTestEnvironment.swift +++ b/Tests/CodexBarTests/GeminiTestEnvironment.swift @@ -8,14 +8,20 @@ struct GeminiTestEnvironment { let homeURL: URL private let geminiDir: URL + private let antigravityDir: URL init() throws { let root = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString) try FileManager.default.createDirectory(at: root, withIntermediateDirectories: true) let geminiDir = root.appendingPathComponent(".gemini") try FileManager.default.createDirectory(at: geminiDir, withIntermediateDirectories: true) + let antigravityDir = root + .appendingPathComponent(".codexbar") + .appendingPathComponent("antigravity") + try FileManager.default.createDirectory(at: antigravityDir, withIntermediateDirectories: true) self.homeURL = root self.geminiDir = geminiDir + self.antigravityDir = antigravityDir } func cleanup() { @@ -52,6 +58,37 @@ struct GeminiTestEnvironment { return object as? [String: Any] ?? [:] } + func writeAntigravityCredentials( + accessToken: String, + refreshToken: String?, + expiry: Date, + idToken: String? = nil, + email: String? = nil, + projectID: String? = nil, + clientID: String? = nil, + clientSecret: String? = nil) throws + { + var payload: [String: Any] = [ + "access_token": accessToken, + "expiry_date": expiry.timeIntervalSince1970 * 1000, + ] + if let refreshToken { payload["refresh_token"] = refreshToken } + if let idToken { payload["id_token"] = idToken } + if let email { payload["email"] = email } + if let projectID { payload["project_id"] = projectID } + if let clientID { payload["client_id"] = clientID } + if let clientSecret { payload["client_secret"] = clientSecret } + let data = try JSONSerialization.data(withJSONObject: payload) + try data.write(to: self.antigravityDir.appendingPathComponent("oauth_creds.json"), options: .atomic) + } + + func readAntigravityCredentials() throws -> [String: Any] { + let url = self.antigravityDir.appendingPathComponent("oauth_creds.json") + let data = try Data(contentsOf: url) + let object = try JSONSerialization.jsonObject(with: data) + return object as? [String: Any] ?? [:] + } + func writeFakeGeminiCLI(includeOAuth: Bool = true, layout: GeminiCLILayout = .npmNested) throws -> URL { let base = self.homeURL.appendingPathComponent("gemini-cli") let binDir = base.appendingPathComponent("bin") diff --git a/Tests/CodexBarTests/ProviderSettingsDescriptorTests.swift b/Tests/CodexBarTests/ProviderSettingsDescriptorTests.swift index 924da2a05..78c0ab8cb 100644 --- a/Tests/CodexBarTests/ProviderSettingsDescriptorTests.swift +++ b/Tests/CodexBarTests/ProviderSettingsDescriptorTests.swift @@ -59,7 +59,8 @@ struct ProviderSettingsDescriptorTests { lastRunAtByID.removeValue(forKey: id) } }, - requestConfirmation: { _ in }) + requestConfirmation: { _ in }, + runLoginFlow: {}) let impl = try #require(ProviderCatalog.implementation(for: provider)) let toggles = impl.settingsToggles(context: context) @@ -115,7 +116,8 @@ struct ProviderSettingsDescriptorTests { setStatusText: { _, _ in }, lastAppActiveRunAt: { _ in nil }, setLastAppActiveRunAt: { _, _ in }, - requestConfirmation: { _ in }) + requestConfirmation: { _ in }, + runLoginFlow: {}) let pickers = CodexProviderImplementation().settingsPickers(context: context) let toggles = CodexProviderImplementation().settingsToggles(context: context) @@ -159,12 +161,13 @@ struct ProviderSettingsDescriptorTests { setStatusText: { _, _ in }, lastAppActiveRunAt: { _ in nil }, setLastAppActiveRunAt: { _, _ in }, - requestConfirmation: { _ in }) + requestConfirmation: { _ in }, + runLoginFlow: {}) let pickers = ClaudeProviderImplementation().settingsPickers(context: context) #expect(pickers.contains(where: { $0.id == "claude-usage-source" })) #expect(pickers.contains(where: { $0.id == "claude-cookie-source" })) let keychainPicker = try #require(pickers.first(where: { $0.id == "claude-keychain-prompt-policy" })) - let optionIDs = Set(keychainPicker.options.map(\.id)) + let optionIDs = Set(keychainPicker.options.map { $0.id }) #expect(optionIDs.contains(ClaudeOAuthKeychainPromptMode.never.rawValue)) #expect(optionIDs.contains(ClaudeOAuthKeychainPromptMode.onlyOnUserAction.rawValue)) #expect(optionIDs.contains(ClaudeOAuthKeychainPromptMode.always.rawValue)) @@ -208,7 +211,8 @@ struct ProviderSettingsDescriptorTests { setStatusText: { _, _ in }, lastAppActiveRunAt: { _ in nil }, setLastAppActiveRunAt: { _, _ in }, - requestConfirmation: { _ in }) + requestConfirmation: { _ in }, + runLoginFlow: {}) let pickers = ClaudeProviderImplementation().settingsPickers(context: context) let keychainPicker = try #require(pickers.first(where: { $0.id == "claude-keychain-prompt-policy" })) @@ -250,7 +254,8 @@ struct ProviderSettingsDescriptorTests { setStatusText: { _, _ in }, lastAppActiveRunAt: { _ in nil }, setLastAppActiveRunAt: { _, _ in }, - requestConfirmation: { _ in }) + requestConfirmation: { _ in }, + runLoginFlow: {}) let pickers = ClaudeProviderImplementation().settingsPickers(context: context) let keychainPicker = try #require(pickers.first(where: { $0.id == "claude-keychain-prompt-policy" })) @@ -312,7 +317,8 @@ struct ProviderSettingsDescriptorTests { setStatusText: { _, _ in }, lastAppActiveRunAt: { _ in nil }, setLastAppActiveRunAt: { _, _ in }, - requestConfirmation: { _ in }) + requestConfirmation: { _ in }, + runLoginFlow: {}) let implementation = KiloProviderImplementation() let toggles = implementation.settingsToggles(context: context)