From d8b52184b68656cf192826d11a4271051a8712ae Mon Sep 17 00:00:00 2001 From: kchro3 Date: Thu, 14 Sep 2023 23:57:08 -0700 Subject: [PATCH] wip --- TypeaheadAI/Actors/SpecialCopyActor.swift | 54 ++++++++-------- TypeaheadAI/Actors/SpecialCutActor.swift | 12 ++-- TypeaheadAI/Actors/SpecialSaveActor.swift | 5 +- TypeaheadAI/AppContextManager.swift | 26 ++++++-- TypeaheadAI/ClientManager.swift | 77 ++++++++--------------- TypeaheadAI/ModalManager.swift | 26 ++++++-- 6 files changed, 104 insertions(+), 96 deletions(-) diff --git a/TypeaheadAI/Actors/SpecialCopyActor.swift b/TypeaheadAI/Actors/SpecialCopyActor.swift index 704b4dd..a7b541f 100644 --- a/TypeaheadAI/Actors/SpecialCopyActor.swift +++ b/TypeaheadAI/Actors/SpecialCopyActor.swift @@ -33,41 +33,37 @@ actor SpecialCopyActor: CanSimulateCopy { } self.logger.debug("copied '\(copiedText)'") + // Clear the modal text and reissue request - self.modalManager.clearText(stickyMode: stickyMode) self.modalManager.showModal(incognito: incognitoMode) - var truncated: String = copiedText - if (copiedText.count > 280) { - truncated = "\(truncated.prefix(280))..." - } - if let activePrompt = self.clientManager.getActivePrompt() { - self.modalManager.setUserMessage("\(activePrompt)\n:\(truncated)") - } else { - self.modalManager.setUserMessage("copied:\n\(truncated)") - } + Task { + await self.modalManager.clearText(stickyMode: stickyMode) + if let activePrompt = self.clientManager.getActivePrompt() { + await self.modalManager.setUserMessage("\(activePrompt):\n\(copiedText)") + } else { + await self.modalManager.setUserMessage("copied:\n\(copiedText)") + } - self.clientManager.predict( - id: UUID(), - copiedText: copiedText, - incognitoMode: incognitoMode, - stream: true, - streamHandler: { result in - switch result { - case .success(let chunk): - Task { - await self.modalManager.appendText(chunk) + self.clientManager.refine( + messages: self.modalManager.messages, + incognitoMode: incognitoMode, + streamHandler: { result in + switch result { + case .success(let chunk): + Task { + await self.modalManager.appendText(chunk) + } + self.logger.info("Received chunk: \(chunk)") + case .failure(let error): + DispatchQueue.main.async { + self.modalManager.setError(error.localizedDescription) + } + self.logger.error("An error occurred: \(error)") } - self.logger.info("Received chunk: \(chunk)") - case .failure(let error): - DispatchQueue.main.async { - self.modalManager.setError(error.localizedDescription) - } - self.logger.error("An error occurred: \(error)") } - }, - completion: { _ in } - ) + ) + } } } } diff --git a/TypeaheadAI/Actors/SpecialCutActor.swift b/TypeaheadAI/Actors/SpecialCutActor.swift index dc85b2f..4b6fcd1 100644 --- a/TypeaheadAI/Actors/SpecialCutActor.swift +++ b/TypeaheadAI/Actors/SpecialCutActor.swift @@ -102,13 +102,15 @@ actor SpecialCutActor { truncated = "\(truncated.prefix(280))..." } - self.modalManager.clearText(stickyMode: stickyMode) self.modalManager.showModal(incognito: incognitoMode) - if let activePrompt = self.clientManager.getActivePrompt() { - self.modalManager.setUserMessage("\(activePrompt)\n:\(truncated)") - } else { - self.modalManager.setUserMessage("cut:\n\(truncated)") + Task { + await self.modalManager.clearText(stickyMode: stickyMode) + if let activePrompt = self.clientManager.getActivePrompt() { + await self.modalManager.setUserMessage("\(activePrompt)\n:\(truncated)") + } else { + await self.modalManager.setUserMessage("cut:\n\(truncated)") + } } self.clientManager.predict( diff --git a/TypeaheadAI/Actors/SpecialSaveActor.swift b/TypeaheadAI/Actors/SpecialSaveActor.swift index b732dc7..0bc35c6 100644 --- a/TypeaheadAI/Actors/SpecialSaveActor.swift +++ b/TypeaheadAI/Actors/SpecialSaveActor.swift @@ -37,7 +37,10 @@ actor SpecialSaveActor: CanSimulateCopy { self.logger.debug("saved '\(copiedText)'") // Force sticky-mode so that it saves the message to the session. - self.modalManager.clearText(stickyMode: true) + Task { + await self.modalManager.clearText(stickyMode: true) + } + self.modalManager.showModal(incognito: incognitoMode) Task { diff --git a/TypeaheadAI/AppContextManager.swift b/TypeaheadAI/AppContextManager.swift index b74bd21..3bd2b4d 100644 --- a/TypeaheadAI/AppContextManager.swift +++ b/TypeaheadAI/AppContextManager.swift @@ -17,7 +17,7 @@ class AppContextManager { category: "AppContextManager" ) - func getActiveAppInfo(completion: @escaping (String?, String?, String?) -> Void) { + func getContext(completion: @escaping (AppContext) -> Void) { self.logger.debug("get active app") if let activeApp = NSWorkspace.shared.frontmostApplication { let appName = activeApp.localizedName @@ -28,17 +28,33 @@ class AppContextManager { self.scriptManager.executeScript { (result, error) in if let error = error { self.logger.error("Failed to execute script: \(error.errorDescription ?? "Unknown error")") - completion(appName, bundleIdentifier, nil) + completion(AppContext( + activeAppName: appName, + activeAppBundleIdentifier: bundleIdentifier, + url: nil + )) } else if let url = result?.stringValue { self.logger.info("Successfully executed script. URL: \(url)") - completion(appName, bundleIdentifier, url) + completion(AppContext( + activeAppName: appName, + activeAppBundleIdentifier: bundleIdentifier, + url: url + )) } } } else { - completion(appName, bundleIdentifier, nil) + completion(AppContext( + activeAppName: appName, + activeAppBundleIdentifier: bundleIdentifier, + url: nil + )) } } else { - completion(nil, nil, nil) + completion(AppContext( + activeAppName: nil, + activeAppBundleIdentifier: nil, + url: nil + )) } } } diff --git a/TypeaheadAI/ClientManager.swift b/TypeaheadAI/ClientManager.swift index 911472f..8f01a37 100644 --- a/TypeaheadAI/ClientManager.swift +++ b/TypeaheadAI/ClientManager.swift @@ -15,11 +15,9 @@ struct RequestPayload: Codable { var userObjective: String var userBio: String var userLang: String - var copiedText: String +// var copiedText: String? = nil var messages: [Message]? - var url: String - var activeAppName: String - var activeAppBundleIdentifier: String + var appContext: AppContext? var onboarding: Bool = false } @@ -50,8 +48,8 @@ class ClientManager { private let session: URLSession private let apiUrl = URL(string: "https://typeahead-ai.fly.dev/get_response")! - private let apiUrlStreaming = URL(string: "https://typeahead-ai.fly.dev/get_response_stream")! -// private let apiUrlStreaming = URL(string: "http://localhost:8080/get_response_stream")! +// private let apiUrlStreaming = URL(string: "https://typeahead-ai.fly.dev/get_response_stream")! + private let apiUrlStreaming = URL(string: "http://localhost:8080/get_response_stream")! private let logger = Logger( subsystem: "ai.typeahead.TypeaheadAI", @@ -84,7 +82,7 @@ class ClientManager { // If objective is not specified in the request, fall back on the active prompt. let objective = userObjective ?? self.promptManager?.getActivePrompt() ?? (stream ? "respond to this in <20 words" : "respond to this") - appContextManager!.getActiveAppInfo { (appName, bundleIdentifier, url) in + appContextManager!.getContext { appContext in if stream { Task { await self.sendStreamRequest( @@ -96,9 +94,7 @@ class ClientManager { userLang: Locale.preferredLanguages.first ?? "", copiedText: copiedText, messages: [], - url: url ?? "", - activeAppName: appName ?? "unknown", - activeAppBundleIdentifier: bundleIdentifier ?? "", + appContext: appContext, incognitoMode: incognitoMode, streamHandler: streamHandler, completion: completion @@ -114,9 +110,7 @@ class ClientManager { userBio: UserDefaults.standard.string(forKey: "bio") ?? "", userLang: Locale.preferredLanguages.first ?? "", copiedText: copiedText, - url: url ?? "", - activeAppName: appName ?? "unknown", - activeAppBundleIdentifier: bundleIdentifier ?? "", + appContext: appContext, incognitoMode: incognitoMode, completion: completion ) @@ -135,7 +129,7 @@ class ClientManager { if let (key, _) = cached, let data = key.data(using: .utf8), let payload = try? JSONDecoder().decode(RequestPayload.self, from: data) { - appContextManager!.getActiveAppInfo { (appName, bundleIdentifier, url) in + appContextManager!.getContext { appContext in Task { await self.sendStreamRequest( id: UUID(), @@ -144,11 +138,9 @@ class ClientManager { userObjective: payload.userObjective, userBio: payload.userBio, userLang: payload.userLang, - copiedText: payload.copiedText, + copiedText: nil, // payload.copiedText, messages: self.sanitizeMessages(messages), - url: payload.url, - activeAppName: appName ?? "unknown", - activeAppBundleIdentifier: bundleIdentifier ?? "", + appContext: appContext, incognitoMode: incognitoMode, streamHandler: streamHandler, completion: { _ in } @@ -157,7 +149,7 @@ class ClientManager { } } else { logger.error("No cached request to refine") - appContextManager!.getActiveAppInfo { (appName, bundleIdentifier, url) in + appContextManager!.getContext { appContext in Task { await self.sendStreamRequest( id: UUID(), @@ -168,9 +160,7 @@ class ClientManager { userLang: Locale.preferredLanguages.first ?? "", copiedText: "", messages: self.sanitizeMessages(messages), - url: url ?? "unknown", - activeAppName: appName ?? "unknown", - activeAppBundleIdentifier: bundleIdentifier ?? "", + appContext: appContext, incognitoMode: false, streamHandler: streamHandler, completion: { _ in } @@ -187,7 +177,7 @@ class ClientManager { streamHandler: @escaping (Result) -> Void ) { if messages.isEmpty { - appContextManager!.getActiveAppInfo { (appName, bundleIdentifier, url) in + appContextManager!.getContext { appContext in Task { await self.sendStreamRequest( id: UUID(), @@ -198,9 +188,7 @@ class ClientManager { userLang: Locale.preferredLanguages.first ?? "", copiedText: "", messages: self.sanitizeMessages(messages), - url: url ?? "unknown", - activeAppName: appName ?? "unknown", - activeAppBundleIdentifier: bundleIdentifier ?? "", + appContext: appContext, incognitoMode: false, onboardingMode: true, streamHandler: streamHandler, @@ -217,7 +205,7 @@ class ClientManager { return } - appContextManager!.getActiveAppInfo { (appName, bundleIdentifier, url) in + appContextManager!.getContext { appContext in Task { await self.sendStreamRequest( id: UUID(), @@ -226,11 +214,9 @@ class ClientManager { userObjective: payload.userObjective, userBio: payload.userBio, userLang: payload.userLang, - copiedText: payload.copiedText, + copiedText: nil, // payload.copiedText, messages: self.sanitizeMessages(messages), - url: payload.url, - activeAppName: appName ?? "unknown", - activeAppBundleIdentifier: bundleIdentifier ?? "", + appContext: appContext, incognitoMode: false, onboardingMode: true, streamHandler: streamHandler, @@ -251,9 +237,7 @@ class ClientManager { /// - userBio: Details about the user. /// - userLang: User's preferred language. /// - copiedText: The text that the user has copied. - /// - url: The URL that the user is currently viewing. - /// - activeAppName: The name of the app that is currently active. - /// - activeAppBundleIdentifier: The bundle identifier of the currently active app. + /// - appContext: Currently active app context /// - incognitoMode: Whether or not the request is sent to an online or offline model. /// - timeout: The timeout for the request. Default is 10 seconds. /// - completion: A closure to be executed once the request is complete. @@ -265,9 +249,7 @@ class ClientManager { userBio: String, userLang: String, copiedText: String, - url: String, - activeAppName: String, - activeAppBundleIdentifier: String, + appContext: AppContext, incognitoMode: Bool, timeout: TimeInterval = 10, completion: @escaping (Result) -> Void @@ -278,10 +260,7 @@ class ClientManager { userObjective: userObjective, userBio: userBio, userLang: userLang, - copiedText: copiedText, - url: url, - activeAppName: activeAppName, - activeAppBundleIdentifier: activeAppBundleIdentifier + appContext: appContext ) if (incognitoMode) { @@ -345,11 +324,9 @@ class ClientManager { userObjective: String, userBio: String, userLang: String, - copiedText: String, + copiedText: String?, messages: [Message], - url: String, - activeAppName: String, - activeAppBundleIdentifier: String, + appContext: AppContext, incognitoMode: Bool, onboardingMode: Bool = false, timeout: TimeInterval = 10, @@ -364,11 +341,9 @@ class ClientManager { userObjective: userObjective, userBio: userBio, userLang: userLang, - copiedText: copiedText, +// copiedText: copiedText, messages: self?.sanitizeMessages(messages), - url: url, - activeAppName: activeAppName, - activeAppBundleIdentifier: activeAppBundleIdentifier, + appContext: appContext, onboarding: onboardingMode ) @@ -455,9 +430,7 @@ class ClientManager { do { var payloadCopy = payload - payloadCopy.url = "" - payloadCopy.activeAppName = "" - payloadCopy.activeAppBundleIdentifier = "" + payloadCopy.appContext = nil let jsonData = try encoder.encode(payloadCopy) if let jsonString = String(data: jsonData, encoding: .utf8) { diff --git a/TypeaheadAI/ModalManager.swift b/TypeaheadAI/ModalManager.swift index 57292e7..474670a 100644 --- a/TypeaheadAI/ModalManager.swift +++ b/TypeaheadAI/ModalManager.swift @@ -16,6 +16,12 @@ struct AttributedOutput: Codable, Equatable { let results: [ParserResult] } +struct AppContext: Codable, Equatable { + var activeAppName: String? + var activeAppBundleIdentifier: String? + var url: String? +} + // TODO: Add to persistence struct Message: Codable, Identifiable, Equatable { let id: UUID @@ -23,6 +29,7 @@ struct Message: Codable, Identifiable, Equatable { var attributed: AttributedOutput? = nil let isCurrentUser: Bool var responseError: String? + var appContext: AppContext? } class ModalManager: ObservableObject { @@ -66,6 +73,7 @@ class ModalManager: ObservableObject { self.triggerFocus = true } + @MainActor func clearText(stickyMode: Bool) { if stickyMode { // TODO: Should we do something smarter here? @@ -98,6 +106,7 @@ class ModalManager: ObservableObject { } /// Set an error message. + @MainActor func setError(_ responseError: String) { if let idx = messages.indices.last, !messages[idx].isCurrentUser { messages[idx].responseError = responseError @@ -187,11 +196,22 @@ class ModalManager: ObservableObject { } /// Add a user message without flushing the modal text. Use this when there is an active prompt. + @MainActor func setUserMessage(_ text: String) { - messages.append(Message(id: UUID(), text: text, isCurrentUser: true)) + self.clientManager?.appContextManager?.getContext { appContext in + self.messages.append( + Message( + id: UUID(), + text: text, + isCurrentUser: true, + appContext: appContext + ) + ) + } } /// When a user responds, flush the current text to the messages array and add the system and user prompts + @MainActor func addUserMessage(_ text: String, incognito: Bool) { self.clientManager?.cancelStreamingTask() @@ -205,9 +225,7 @@ class ModalManager: ObservableObject { } self.logger.info("Received chunk: \(chunk)") case .failure(let error): - Task { - self.setError(error.localizedDescription) - } + self.setError(error.localizedDescription) self.logger.error("An error occurred: \(error)") } }