diff --git a/mac/FreeChat.xcodeproj/project.pbxproj b/mac/FreeChat.xcodeproj/project.pbxproj index 79d473c..d9c6c52 100644 --- a/mac/FreeChat.xcodeproj/project.pbxproj +++ b/mac/FreeChat.xcodeproj/project.pbxproj @@ -58,7 +58,13 @@ A1F617562A782E4F00F2048C /* ConversationView.swift in Sources */ = {isa = PBXBuildFile; fileRef = A1F617552A782E4F00F2048C /* ConversationView.swift */; }; A1F617582A7836AE00F2048C /* Message+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = A1F617572A7836AE00F2048C /* Message+Extensions.swift */; }; A1F6175B2A7838F700F2048C /* Conversation+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = A1F6175A2A7838F700F2048C /* Conversation+Extensions.swift */; }; + DE16617B2B8A40D100826556 /* OpenAIBackend.swift in Sources */ = {isa = PBXBuildFile; fileRef = DE16617A2B8A40D100826556 /* OpenAIBackend.swift */; }; + DE7250E12B966D23006A76DF /* String+TrimQuotes.swift in Sources */ = {isa = PBXBuildFile; fileRef = DE7250E02B966D22006A76DF /* String+TrimQuotes.swift */; }; DEA8CF572B51938B007A4CE7 /* FreeChatAppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEA8CF562B51938B007A4CE7 /* FreeChatAppDelegate.swift */; }; + DEAE3D482B987DE700257A69 /* Backend.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEAE3D472B987DE700257A69 /* Backend.swift */; }; + DEAE3D4A2B987EA400257A69 /* OllamaBackend.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEAE3D492B987EA400257A69 /* OllamaBackend.swift */; }; + DEAE3D4C2B987EB300257A69 /* LlamaBackend.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEAE3D4B2B987EB300257A69 /* LlamaBackend.swift */; }; + DEAE3D4E2B987EBC00257A69 /* LocalBackend.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEAE3D4D2B987EBC00257A69 /* LocalBackend.swift */; }; DEEA39CC2B586F3800992592 /* ServerHealth.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEEA39CB2B586F3800992592 /* ServerHealth.swift */; }; /* End PBXBuildFile section */ @@ -176,7 +182,13 @@ A1F617552A782E4F00F2048C /* ConversationView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationView.swift; sourceTree = ""; }; A1F617572A7836AE00F2048C /* Message+Extensions.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "Message+Extensions.swift"; sourceTree = ""; }; A1F6175A2A7838F700F2048C /* Conversation+Extensions.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "Conversation+Extensions.swift"; sourceTree = ""; }; + DE16617A2B8A40D100826556 /* OpenAIBackend.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = OpenAIBackend.swift; path = FreeChat/Models/NPC/OpenAIBackend.swift; sourceTree = SOURCE_ROOT; }; + DE7250E02B966D22006A76DF /* String+TrimQuotes.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "String+TrimQuotes.swift"; sourceTree = ""; }; DEA8CF562B51938B007A4CE7 /* FreeChatAppDelegate.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FreeChatAppDelegate.swift; sourceTree = ""; }; + DEAE3D472B987DE700257A69 /* Backend.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Backend.swift; sourceTree = ""; }; + DEAE3D492B987EA400257A69 /* OllamaBackend.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OllamaBackend.swift; sourceTree = ""; }; + DEAE3D4B2B987EB300257A69 /* LlamaBackend.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LlamaBackend.swift; sourceTree = ""; }; + DEAE3D4D2B987EBC00257A69 /* LocalBackend.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalBackend.swift; sourceTree = ""; }; DEEA39CB2B586F3800992592 /* ServerHealth.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ServerHealth.swift; sourceTree = ""; }; /* End PBXFileReference section */ @@ -259,7 +271,13 @@ A137A3872AB502DB00BE1AE0 /* ggml-metal.metal */, A17A2E122A79A005006CDD90 /* Agent.swift */, A17A2E132A79A005006CDD90 /* LlamaServer.swift */, + DEAE3D472B987DE700257A69 /* Backend.swift */, + DE16617A2B8A40D100826556 /* OpenAIBackend.swift */, + DEAE3D492B987EA400257A69 /* OllamaBackend.swift */, + DEAE3D4B2B987EB300257A69 /* LlamaBackend.swift */, + DEAE3D4D2B987EBC00257A69 /* LocalBackend.swift */, DEEA39CB2B586F3800992592 /* ServerHealth.swift */, + DE7250E02B966D22006A76DF /* String+TrimQuotes.swift */, A137A3822AB4FD4800BE1AE0 /* freechat-server */, A1A286F32A7E17750004967A /* server-watchdog */, A1A286F92A7E197F0004967A /* README.md */, @@ -577,12 +595,17 @@ A1F617582A7836AE00F2048C /* Message+Extensions.swift in Sources */, A13C8C682A902A1200EC18D8 /* CGKeycode+Extensions.swift in Sources */, A15D50D22A7F539800FC1681 /* NavList.swift in Sources */, + DEAE3D4C2B987EB300257A69 /* LlamaBackend.swift in Sources */, + DE16617B2B8A40D100826556 /* OpenAIBackend.swift in Sources */, A1156D342AD1F5EF00081313 /* Templates.swift in Sources */, A1D4B49D2B9A780B00B9C4BE /* AgentDefaults.swift in Sources */, A1F617262A782AA100F2048C /* FreeChat.swift in Sources */, A1156D2F2AD0954C00081313 /* TemplateManager.swift in Sources */, A1E4A6942A82B41F00BF9D34 /* Model+Extensions.swift in Sources */, + DEAE3D4A2B987EA400257A69 /* OllamaBackend.swift in Sources */, A12B52DE2AA5228100658707 /* EditModels.swift in Sources */, + DEAE3D482B987DE700257A69 /* Backend.swift in Sources */, + DEAE3D4E2B987EBC00257A69 /* LocalBackend.swift in Sources */, A17AB1C22ABB4B5E00CD3100 /* CircleMenuStyle.swift in Sources */, A15D50CF2A7EF73E00FC1681 /* MessageTextField.swift in Sources */, A1CA32442AAF877600F9D488 /* ConversationManager.swift in Sources */, @@ -595,6 +618,7 @@ A16FFF8B2B2E35D200E6AAE2 /* GPU.swift in Sources */, A18A8BB32B24FC0400D2197C /* AISettingsView.swift in Sources */, A1F617562A782E4F00F2048C /* ConversationView.swift in Sources */, + DE7250E12B966D23006A76DF /* String+TrimQuotes.swift in Sources */, A13C8C5A2A8FEEE400EC18D8 /* SplashCodeSyntaxHighlighter.swift in Sources */, A15D50D42A80BCA900FC1681 /* SettingsView.swift in Sources */, ); diff --git a/mac/FreeChat.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mac/FreeChat.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 985f0f9..3d62681 100644 --- a/mac/FreeChat.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/mac/FreeChat.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -5,8 +5,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/Recouse/EventSource.git", "state" : { - "revision" : "fcd7152a3106d75287c7303bba40a4761e5b7f6d", - "version" : "0.0.5" + "revision" : "ffaa978620b19c891d107941c1b36d18836e8ecb", + "version" : "0.0.7" } }, { @@ -27,24 +27,6 @@ "version" : "0.16.0" } }, - { - "identity" : "swift-async-algorithms", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-async-algorithms.git", - "state" : { - "revision" : "9cfed92b026c524674ed869a4ff2dcfdeedf8a2a", - "version" : "0.1.0" - } - }, - { - "identity" : "swift-collections", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-collections.git", - "state" : { - "revision" : "937e904258d22af6e447a0b72c0bc67583ef64a2", - "version" : "1.0.4" - } - }, { "identity" : "swift-markdown-ui", "kind" : "remoteSourceControl", diff --git a/mac/FreeChat/Chats.xcdatamodeld/Mantras.xcdatamodel/contents b/mac/FreeChat/Chats.xcdatamodeld/Mantras.xcdatamodel/contents index 2b870f4..1066110 100644 --- a/mac/FreeChat/Chats.xcdatamodeld/Mantras.xcdatamodel/contents +++ b/mac/FreeChat/Chats.xcdatamodeld/Mantras.xcdatamodel/contents @@ -1,5 +1,16 @@ - + + + + + + + + + + + + @@ -37,4 +48,4 @@ - \ No newline at end of file + diff --git a/mac/FreeChat/FreeChatAppDelegate.swift b/mac/FreeChat/FreeChatAppDelegate.swift index c63b9f8..718fcc8 100644 --- a/mac/FreeChat/FreeChatAppDelegate.swift +++ b/mac/FreeChat/FreeChatAppDelegate.swift @@ -7,21 +7,25 @@ import SwiftUI class FreeChatAppDelegate: NSObject, NSApplicationDelegate, ObservableObject { @AppStorage("selectedModelId") private var selectedModelId: String? - + @AppStorage("backendTypeID") private var backendTypeID: String = BackendType.local.rawValue + func application(_ application: NSApplication, open urls: [URL]) { + backendTypeID = BackendType.local.rawValue let viewContext = PersistenceController.shared.container.viewContext do { let req = Model.fetchRequest() req.predicate = NSPredicate(format: "name IN %@", urls.map({ $0.lastPathComponent })) - let existingModels = try viewContext.fetch(req).compactMap({ $0.url }) + let existingModels = try viewContext.fetch(req) for url in urls { - guard !existingModels.contains(url) else { continue } + guard !existingModels.compactMap({ $0.url }).contains(url) else { continue } let insertedModel = try Model.create(context: viewContext, fileURL: url) selectedModelId = insertedModel.id?.uuidString } - - NotificationCenter.default.post(name: NSNotification.Name("selectedModelDidChange"), object: selectedModelId) + + if urls.count == 1, let modelID = existingModels.first(where: { $0.url == urls.first })?.id?.uuidString { selectedModelId = modelID } + + NotificationCenter.default.post(name: NSNotification.Name("selectedLocalModelDidChange"), object: selectedModelId) NotificationCenter.default.post(name: NSNotification.Name("needStartNewConversation"), object: selectedModelId) } catch { print("error saving model:", error) diff --git a/mac/FreeChat/Models/ConversationManager.swift b/mac/FreeChat/Models/ConversationManager.swift index 4bb6f8b..33b3e35 100644 --- a/mac/FreeChat/Models/ConversationManager.swift +++ b/mac/FreeChat/Models/ConversationManager.swift @@ -15,6 +15,7 @@ class ConversationManager: ObservableObject { var summonRegistered = false + @AppStorage("backendTypeID") private var backendTypeID: String? @AppStorage("systemPrompt") private var systemPrompt: String = DEFAULT_SYSTEM_PROMPT @AppStorage("contextLength") private var contextLength: Int = DEFAULT_CONTEXT_LENGTH @@ -72,10 +73,8 @@ class ConversationManager: ObservableObject { @MainActor func rebootAgent(systemPrompt: String? = nil, model: Model, viewContext: NSManagedObjectContext) { + guard let url = model.url else { return } let systemPrompt = systemPrompt ?? self.systemPrompt - guard let url = model.url else { - return - } Task { await agent.llama.stopServer() @@ -83,12 +82,25 @@ class ConversationManager: ObservableObject { let messages = currentConversation.orderedMessages.map { $0.text ?? "" } let convoPrompt = model.template.run(systemPrompt: systemPrompt, messages: messages) agent = Agent(id: "Llama", prompt: convoPrompt, systemPrompt: systemPrompt, modelPath: url.path, contextLength: contextLength) - loadingModelId = model.id?.uuidString - model.error = nil + do { + let backendType: BackendType = BackendType(rawValue: backendTypeID ?? "") ?? .local + let context = PersistenceController.shared.container.newBackgroundContext() + let config = try fetchBackendConfig(context: context) ?? BackendConfig(context: context) + agent.createBackend(backendType, contextLength: contextLength, config: config) + } catch { print("error fetching backend config", error) } + loadingModelId = model.id?.uuidString + model.error = nil loadingModelId = nil try? viewContext.save() } } + + private func fetchBackendConfig(context: NSManagedObjectContext) throws -> BackendConfig? { + let backendType: BackendType = BackendType(rawValue: backendTypeID ?? "") ?? .local + let req = BackendConfig.fetchRequest() + req.predicate = NSPredicate(format: "backendType == %@", backendType.rawValue) + return try context.fetch(req).first + } } diff --git a/mac/FreeChat/Models/NPC/Agent.swift b/mac/FreeChat/Models/NPC/Agent.swift index 79e974d..f87dbfc 100644 --- a/mac/FreeChat/Models/NPC/Agent.swift +++ b/mac/FreeChat/Models/NPC/Agent.swift @@ -21,42 +21,44 @@ class Agent: ObservableObject { // each agent runs their own server var llama: LlamaServer + private var backend: Backend init(id: String, prompt: String, systemPrompt: String, modelPath: String, contextLength: Int) { self.id = id self.prompt = prompt self.systemPrompt = systemPrompt - llama = LlamaServer(modelPath: modelPath, contextLength: contextLength) + self.llama = LlamaServer(modelPath: modelPath, contextLength: contextLength) + self.backend = LocalBackend(baseURL: BackendType.local.defaultURL, apiKey: nil) + } + + func createBackend(_ backend: BackendType, contextLength: Int, config: BackendConfig) { + let baseURL = config.baseURL ?? backend.defaultURL + + switch backend { + case .local: + self.backend = LocalBackend(baseURL: baseURL, apiKey: config.apiKey) + case .llama: + self.backend = LlamaBackend(baseURL: baseURL, apiKey: config.apiKey) + case .openai: + self.backend = OpenAIBackend(baseURL: baseURL, apiKey: config.apiKey) + case .ollama: + self.backend = OllamaBackend(baseURL: baseURL, apiKey: config.apiKey) + } } // this is the main loop of the agent // listen -> respond -> update mental model and save checkpoint // we respond before updating to avoid a long delay after user input - func listenThinkRespond( - speakerId: String, messages: [String], template: Template, temperature: Double? - ) async throws -> LlamaServer.CompleteResponse { - if status == .cold { - status = .coldProcessing - } else { - status = .processing - } - - prompt = template.run(systemPrompt: systemPrompt, messages: messages) - + func listenThinkRespond(speakerId: String, params: CompleteParams) async throws -> CompleteResponseSummary { + status = status == .cold ? .coldProcessing : .processing pendingMessage = "" - - let response = try await llama.complete( - prompt: prompt, stop: template.stopWords, temperature: temperature - ) { partialResponse in - DispatchQueue.main.async { - self.handleCompletionProgress(partialResponse: partialResponse) - } + for try await partialResponse in try await backend.complete(params: params) { + self.pendingMessage += partialResponse + self.prompt = pendingMessage } - - pendingMessage = response.text status = .ready - return response + return CompleteResponseSummary(text: pendingMessage, responseStartSeconds: 0) } func handleCompletionProgress(partialResponse: String) { @@ -66,13 +68,13 @@ class Agent: ObservableObject { func interrupt() async { if status != .processing, status != .coldProcessing { return } - await llama.interrupt() + await backend.interrupt() } func warmup() async throws { if prompt.isEmpty, systemPrompt.isEmpty { return } do { - _ = try await llama.complete(prompt: prompt, stop: nil, temperature: nil) + _ = try await backend.complete(params: CompleteParams(messages: [], model: "", numCTX: 2048, temperature: 0.7)) status = .ready } catch { status = .cold diff --git a/mac/FreeChat/Models/NPC/Backend.swift b/mac/FreeChat/Models/NPC/Backend.swift new file mode 100644 index 0000000..eb4ad97 --- /dev/null +++ b/mac/FreeChat/Models/NPC/Backend.swift @@ -0,0 +1,162 @@ +// +// Backend.swift +// FreeChat +// + +import Foundation +import EventSource + +protocol Backend: Actor, Sendable { + var type: BackendType { get } + var baseURL: URL { get } + var apiKey: String? { get } + var interrupted: Bool { get set } + + func complete(params: CompleteParams) async throws -> AsyncStream + func buildRequest(path: String, params: CompleteParams) -> URLRequest + func interrupt() async + + func listModels() async throws -> [String] +} + +extension Backend { + func complete(params: CompleteParams) async throws -> AsyncStream { + let request = buildRequest(path: "/v1/chat/completions", params: params) + self.interrupted = false + + return AsyncStream { continuation in + Task.detached { + let eventSource = EventSource() + let dataTask = eventSource.dataTask(for: request) + L: for await event in dataTask.events() { + guard await !self.interrupted else { break L } + switch event { + case .open: continue + case .error(let error): + print("EventSource server error:", error.localizedDescription) + break L + case .message(let message): + if let response = try CompleteResponse.from(data: message.data?.data(using: .utf8)), + let choice = response.choices.first { + if let content = choice.delta.content?.trimTrailingQuote() { continuation.yield(content) } + if choice.finishReason != nil { break L } + } + case .closed: break L + } + } + + continuation.finish() + } + } + } + + func interrupt() async { interrupted = true } + + func buildRequest(path: String, params: CompleteParams) -> URLRequest { + var request = URLRequest(url: baseURL.appendingPathComponent(path)) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue("text/event-stream", forHTTPHeaderField: "Accept") + request.setValue("keep-alive", forHTTPHeaderField: "Connection") + request.setValue("Bearer: \(apiKey ?? "none")", forHTTPHeaderField: "Authorization") + request.httpBody = params.toJSON().data(using: .utf8) + + return request + } +} + +enum BackendType: String, CaseIterable { + case local = "This Computer (default)" + case llama = "Llama.cpp" + case openai = "OpenAI" + case ollama = "Ollama" + + var defaultURL: URL { + switch self { + case .local: return URL(string: "http://127.0.0.1:8690")! + case .llama: return URL(string: "http://127.0.0.1:8690")! + case .ollama: return URL(string: "http://127.0.0.1:11434")! + case .openai: return URL(string: "https://api.openai.com:443")! + } + } + + var howtoConfigure: AttributedString { + switch self { + case .local: try! AttributedString(markdown: NSLocalizedString("Runs on this computer offline using llama.cpp. No configuration required", comment: "No configuration")) + case .llama: try! AttributedString(markdown: NSLocalizedString("Llama.cpp is an efficient server than runs more than just LLaMa models. [Learn more](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md)", comment: "What it is and Usage link")) + case .openai: try! AttributedString(markdown: NSLocalizedString("Configure OpenAI's ChatGPT. [Learn more](https://openai.com/product)", comment: "What it is and Usage link")) + case .ollama: try! AttributedString(markdown: NSLocalizedString("Ollama runs large language models locally. [Learn more](https://ollama.com)", comment: "What it is and Usage link")) + } + } +} + +struct RoleMessage: Codable { + let role: String? + let content: String? +} + +struct CompleteParams: Encodable { + enum Mirostat: Int, Encodable { + case disabled = 0 + case v1 = 1 + case v2 = 2 + } + let messages: [RoleMessage] + let model: String + let mirostat: Mirostat = .disabled + let mirostatETA: Float = 0.1 + let mirostatTAU: Float = 5 + let numCTX: Int // 2048 + let numGQA = 1 + let numGPU: Int? = nil + let numThread: Int? = nil + let repeatLastN = 64 + let repeatPenalty: Float = 1.1 + let temperature: Float // 0.7 + let seed: Int? = nil + let stop: [String]? = nil + let tfsZ: Float? = nil + let numPredict = 128 + let topK = 40 + let topP: Float = 0.9 + let template: String? = nil + let cachePrompt = true + let stream = true + let keepAlive = true + + func toJSON() -> String { + let encoder = JSONEncoder() + encoder.keyEncodingStrategy = .convertToSnakeCase + let jsonData = try? encoder.encode(self) + return String(data: jsonData!, encoding: .utf8)! + } +} + +struct CompleteResponse: Decodable { + struct Choice: Decodable { + let index: Int + let delta: RoleMessage + let finishReason: String? + } + let id: String + let object: String + let created: Int + let model: String + let systemFingerprint: String? + let choices: [Choice] + + static func from(data: Data?) throws -> CompleteResponse? { + guard let data else { return nil } + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + return try decoder.decode(CompleteResponse.self, from: data) + } +} + +struct CompleteResponseSummary { + var text: String + var responseStartSeconds: Double + var predictedPerSecond: Double? + var modelName: String? + var nPredicted: Int? + } diff --git a/mac/FreeChat/Models/NPC/LlamaBackend.swift b/mac/FreeChat/Models/NPC/LlamaBackend.swift new file mode 100644 index 0000000..13df7b8 --- /dev/null +++ b/mac/FreeChat/Models/NPC/LlamaBackend.swift @@ -0,0 +1,50 @@ +// +// LlamaBackend.swift +// FreeChat +// + +import Foundation + +actor LlamaBackend: Backend { + var type: BackendType = .llama + var baseURL: URL + var apiKey: String? + var interrupted = false + + init(baseURL: URL, apiKey: String?) { + self.baseURL = baseURL + self.apiKey = apiKey + } + + deinit { interrupted = true } + + struct ModelListResponse: Decodable { + struct Model: Decodable { + struct Meta: Decodable { + let nCtxTrain: Int + let nEmbd: Int + let nParams: Int + let nVocab: Int + let size: Int + let vocabType: Int + } + let id: String + let created: Int + let meta: Meta + let object: String + } + let data: [Model] + + static func from(data: Data) throws -> ModelListResponse { + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + return try decoder.decode(ModelListResponse.self, from: data) + } + } + + nonisolated func listModels() async throws -> [String] { + let url = await baseURL.appendingPathComponent("/v1/models") + let (data, _) = try await URLSession.shared.data(from: url) + return try ModelListResponse.from(data: data).data.compactMap({ $0.id.components(separatedBy: "/").last }) + } +} diff --git a/mac/FreeChat/Models/NPC/LlamaServer.swift b/mac/FreeChat/Models/NPC/LlamaServer.swift index eafd5ed..e2bef47 100644 --- a/mac/FreeChat/Models/NPC/LlamaServer.swift +++ b/mac/FreeChat/Models/NPC/LlamaServer.swift @@ -1,29 +1,7 @@ -@preconcurrency import EventSource import Foundation import SwiftUI import os.lock -func removeUnmatchedTrailingQuote(_ inputString: String) -> String { - var outputString = inputString - if inputString.last != "\"" { return outputString } - - // Count the number of quotes in the string - let countOfQuotes = outputString.reduce( - 0, - { (count, character) -> Int in - return character == "\"" ? count + 1 : count - }) - - // If there is an odd number of quotes, remove the last one - if countOfQuotes % 2 != 0 { - if let indexOfLastQuote = outputString.lastIndex(of: "\"") { - outputString.remove(at: indexOfLastQuote) - } - } - - return outputString -} - actor LlamaServer { var modelPath: String? @@ -35,31 +13,13 @@ actor LlamaServer { private var process = Process() private var serverUp = false private var serverErrorMessage = "" - private var eventSource: EventSource? - private let host: String - private let port: String - private let scheme: String - private var interrupted = false + private let url = URL(string: "http://127.0.0.1:8690")! private var monitor = Process() init(modelPath: String, contextLength: Int) { self.modelPath = modelPath self.contextLength = contextLength - self.scheme = "http" - self.host = "127.0.0.1" - self.port = "8690" - } - - init(contextLength: Int, tls: Bool, host: String, port: String) { - self.contextLength = contextLength - self.scheme = tls ? "https" : "http" - self.host = host - self.port = port - } - - private func url(_ path: String) -> URL { - URL(string: "\(scheme)://\(host):\(port)\(path)")! } // Start a monitor process that will terminate the server when our app dies. @@ -98,7 +58,7 @@ actor LlamaServer { print("started monitor for \(serverPID)") } - private func startServer() async throws { + func startServer() async throws { guard !process.isRunning, let modelPath = self.modelPath else { return } stopServer() process = Process() @@ -111,22 +71,18 @@ actor LlamaServer { "--model", modelPath, "--threads", "\(max(1, Int(ceil(Double(processes) / 3.0 * 2.0))))", "--ctx-size", "\(contextLength)", - "--port", port, + "--port", "8690", "--n-gpu-layers", gpu.available && useGPU ? "4" : "0", ] - print("starting llama.cpp server \(process.arguments!.joined(separator: " "))") process.standardInput = FileHandle.nullDevice - // To debug with server's output, comment these 2 lines to inherit stdout. process.standardOutput = FileHandle.nullDevice process.standardError = FileHandle.nullDevice try process.run() - try await waitForServer() - try startAppMonitor(serverPID: process.processIdentifier) let endTime = DispatchTime.now() @@ -138,127 +94,12 @@ actor LlamaServer { } func stopServer() { - if process.isRunning { - process.terminate() - } - if monitor.isRunning { - monitor.terminate() - } - } - - func complete( - prompt: String, stop: [String]?, temperature: Double?, - progressHandler: (@Sendable (String) -> Void)? = nil - ) async throws -> CompleteResponse { - #if DEBUG - print("START PROMPT\n \(prompt) \nEND PROMPT\n\n") - #endif - - let start = CFAbsoluteTimeGetCurrent() - try await startServer() - - // hit localhost for completion - var params = CompleteParams( - prompt: prompt, - stop: stop ?? [ - "", - "\n\(Message.USER_SPEAKER_ID):", - "\n\(Message.USER_SPEAKER_ID.lowercased()):", - "[/INST]", - "[INST]", - "USER:", - ] - ) - if let t = temperature { params.temperature = t } - - var request = URLRequest(url: url("/completion")) - - request.httpMethod = "POST" - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - request.setValue("text/event-stream", forHTTPHeaderField: "Accept") - request.setValue("keep-alive", forHTTPHeaderField: "Connection") - request.httpBody = params.toJSON().data(using: .utf8) - - // Use EventSource to receive server sent events - eventSource = EventSource(request: request) - eventSource!.connect() - - var response = "" - var responseDiff = 0.0 - var stopResponse: StopResponse? - listenLoop: for await event in eventSource!.events { - switch event { - case .open: - continue listenLoop - case .error(let error): - print("llama.cpp EventSource server error:", error.localizedDescription) - case .message(let message): - // parse json in message.data string then print the data.content value and append it to response - if let data = message.data?.data(using: .utf8) { - let decoder = JSONDecoder() - - do { - let responseObj = try decoder.decode(Response.self, from: data) - let fragment = responseObj.content - response.append(fragment) - progressHandler?(fragment) - if responseDiff == 0 { - responseDiff = CFAbsoluteTimeGetCurrent() - start - } - - if responseObj.stop { - do { - stopResponse = try decoder.decode(StopResponse.self, from: data) - } catch { - print("error decoding stopResponse", error as Any, data) - } - #if DEBUG - print( - "server.cpp stopResponse", - NSString(data: data, encoding: String.Encoding.utf8.rawValue) ?? "missing") - #endif - break listenLoop - } - } catch { - print("error decoding responseObj", error as Any, data) - break listenLoop - } - } - case .closed: - print("llama.cpp EventSource closed") - break listenLoop - } - } - - if responseDiff > 0 { - print("response: \(response)") - print("\n\n🦙 started response in \(responseDiff) seconds") - } - - // adding a trailing quote or space is a common mistake with the smaller model output - let cleanText = removeUnmatchedTrailingQuote(response).trimmingCharacters( - in: .whitespacesAndNewlines) - - let modelName = stopResponse?.model.split(separator: "/").last?.map { String($0) }.joined() - return CompleteResponse( - text: cleanText, - responseStartSeconds: responseDiff, - predictedPerSecond: stopResponse?.timings.predicted_per_second, - modelName: modelName, - nPredicted: stopResponse?.tokens_predicted - ) - } - - func interrupt() async { - if let eventSource, eventSource.readyState != .closed { - await eventSource.close() - } - interrupted = true + if process.isRunning { process.terminate() } + if monitor.isRunning { monitor.terminate() } } private func waitForServer() async throws { guard process.isRunning else { return } - interrupted = false serverErrorMessage = "" guard let modelPath = self.modelPath else { return } @@ -266,7 +107,7 @@ actor LlamaServer { modelPath.split(separator: "/").last?.map { String($0) }.joined() ?? "Unknown model name" let serverHealth = ServerHealth() - await serverHealth.updateURL(url("/health")) + await serverHealth.updateURL(url.appendingPathComponent("/health")) await serverHealth.check() var timeout = 60 @@ -287,68 +128,6 @@ actor LlamaServer { } } } - - struct CompleteResponse { - var text: String - var responseStartSeconds: Double - var predictedPerSecond: Double? - var modelName: String? - var nPredicted: Int? - } - - struct CompleteParams: Codable { - var prompt: String - var stop: [String] = [""] - var stream = true - var n_threads = 6 - - var n_predict = -1 - var temperature = DEFAULT_TEMP - var repeat_last_n = 128 // 0 = disable penalty, -1 = context size - var repeat_penalty = 1.18 // 1.0 = disabled - var top_k = 40 // <= 0 to use vocab size - var top_p = 0.95 // 1.0 = disabled - var tfs_z = 1.0 // 1.0 = disabled - var typical_p = 1.0 // 1.0 = disabled - var presence_penalty = 0.0 // 0.0 = disabled - var frequency_penalty = 0.0 // 0.0 = disabled - var mirostat = 0 // 0/1/2 - var mirostat_tau = 5 // target entropy - var mirostat_eta = 0.1 // learning rate - var cache_prompt = true - - func toJSON() -> String { - let encoder = JSONEncoder() - encoder.outputFormatting = .prettyPrinted - let jsonData = try? encoder.encode(self) - return String(data: jsonData!, encoding: .utf8)! - } - } - - struct Timings: Codable { - let prompt_n: Int - let prompt_ms: Double - let prompt_per_token_ms: Double - let prompt_per_second: Double? - - let predicted_n: Int - let predicted_ms: Double - let predicted_per_token_ms: Double - let predicted_per_second: Double? - } - - struct Response: Codable { - let content: String - let stop: Bool - } - - struct StopResponse: Codable { - let content: String - let model: String - let tokens_predicted: Int - let tokens_evaluated: Int - let timings: Timings - } } enum LlamaServerError: LocalizedError { diff --git a/mac/FreeChat/Models/NPC/LocalBackend.swift b/mac/FreeChat/Models/NPC/LocalBackend.swift new file mode 100644 index 0000000..7298c60 --- /dev/null +++ b/mac/FreeChat/Models/NPC/LocalBackend.swift @@ -0,0 +1,27 @@ +// +// LocalBackend.swift +// FreeChat +// + +import Foundation + +actor LocalBackend: Backend { + var type: BackendType = .local + var baseURL: URL + var apiKey: String? + var interrupted = false + + init(baseURL: URL, apiKey: String?) { + self.baseURL = baseURL + self.apiKey = apiKey + } + + deinit { interrupted = true } + + func listModels() async throws -> [String] { + let req = Model.fetchRequest() + req.sortDescriptors = [NSSortDescriptor(key: "size", ascending: true)] + let context = PersistenceController.shared.container.newBackgroundContext() + return try context.fetch(req).compactMap({ $0.url?.lastPathComponent }) + } +} diff --git a/mac/FreeChat/Models/NPC/OllamaBackend.swift b/mac/FreeChat/Models/NPC/OllamaBackend.swift new file mode 100644 index 0000000..43e4fa8 --- /dev/null +++ b/mac/FreeChat/Models/NPC/OllamaBackend.swift @@ -0,0 +1,52 @@ +// +// OllamaBackend.swift +// FreeChat +// + +import Foundation + +actor OllamaBackend: Backend { + var type: BackendType = .ollama + var baseURL: URL + var apiKey: String? + var interrupted = false + + init(baseURL: URL, apiKey: String?) { + self.baseURL = baseURL + self.apiKey = apiKey + } + + deinit { interrupted = true } + + struct TagsResponse: Decodable { + struct Model: Decodable { + struct Details: Decodable { + let parentModel: String? + let format: String + let family: String + let families: [String]? + let parameterSize: String + let quantizationLevel: String + } + let name: String + let model: String + let modifiedAt: String + let size: Int + let digest: String + let details: Details + } + let models: [Model] + + static func from(data: Data) throws -> TagsResponse { + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + return try decoder.decode(TagsResponse.self, from: data) + } + } + + nonisolated func listModels() async throws -> [String] { + let url = await baseURL.appendingPathComponent("/api/tags") + let (data, _) = try await URLSession.shared.data(from: url) + return try TagsResponse.from(data: data).models.map({ $0.name }) + } +} diff --git a/mac/FreeChat/Models/NPC/OpenAIBackend.swift b/mac/FreeChat/Models/NPC/OpenAIBackend.swift new file mode 100644 index 0000000..b73d99d --- /dev/null +++ b/mac/FreeChat/Models/NPC/OpenAIBackend.swift @@ -0,0 +1,43 @@ +// +// OpenAIBackend.swift +// FreeChat +// + +import Foundation + +actor OpenAIBackend: Backend { + var type: BackendType = .openai + let baseURL: URL + let apiKey: String? + var interrupted = false + + init(baseURL: URL, apiKey: String?) { + self.baseURL = baseURL + self.apiKey = apiKey + } + + deinit { interrupted = true } + + nonisolated func listModels() -> [String] { + [ + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4-1106-vision-preview", + "gpt-4", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0613", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-instruct", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "babbage-002", + "davinci-002", + ] + } +} diff --git a/mac/FreeChat/Models/NPC/ServerHealth.swift b/mac/FreeChat/Models/NPC/ServerHealth.swift index 1c294e9..0d06ae2 100644 --- a/mac/FreeChat/Models/NPC/ServerHealth.swift +++ b/mac/FreeChat/Models/NPC/ServerHealth.swift @@ -15,16 +15,11 @@ fileprivate struct ServerHealthRequest { let config = URLSessionConfiguration.default config.timeoutIntervalForRequest = 3 config.timeoutIntervalForResource = 1 - let (data, response) = try await URLSession(configuration: config).data(from: url) - guard let responseCode = (response as? HTTPURLResponse)?.statusCode, - responseCode > 0 + let (_, response) = try await URLSession(configuration: config).data(from: url) + guard let code = (response as? HTTPURLResponse)?.statusCode, code > 0 else { throw ServerHealthError.invalidResponse } - guard let json = try JSONSerialization.jsonObject(with: data, options: []) as? [String: Any], - let jsonStatus: String = json["status"] as? String - else { throw ServerHealthError.invalidResponse } - - return responseCode == 200 && jsonStatus == "ok" + return code == 200 } } diff --git a/mac/FreeChat/Models/NPC/String+TrimQuotes.swift b/mac/FreeChat/Models/NPC/String+TrimQuotes.swift new file mode 100644 index 0000000..cc91420 --- /dev/null +++ b/mac/FreeChat/Models/NPC/String+TrimQuotes.swift @@ -0,0 +1,23 @@ +// +// String+TrimQuotes.swift +// FreeChat +// + +import Foundation + +extension String { + func trimTrailingQuote() -> String { + guard self.last == "\"" else { return self } + + // Count the number of quotes in the string + let countOfQuotes = self.filter({ $0 == "\"" }).count + guard countOfQuotes % 2 != 0 else { return self } + var outputString = self + // If there is an odd number of quotes, remove the last one + if let indexOfLastQuote = self.lastIndex(of: "\"") { + outputString.remove(at: indexOfLastQuote) + } + + return outputString + } +} diff --git a/mac/FreeChat/Views/ConversationView/ConversationView.swift b/mac/FreeChat/Views/ConversationView/ConversationView.swift index 08d6d29..5520ac1 100644 --- a/mac/FreeChat/Views/ConversationView/ConversationView.swift +++ b/mac/FreeChat/Views/ConversationView/ConversationView.swift @@ -13,20 +13,13 @@ struct ConversationView: View, Sendable { @Environment(\.managedObjectContext) private var viewContext @EnvironmentObject private var conversationManager: ConversationManager + @AppStorage("backendTypeID") private var backendTypeID: String? @AppStorage("selectedModelId") private var selectedModelId: String? @AppStorage("systemPrompt") private var systemPrompt: String = DEFAULT_SYSTEM_PROMPT @AppStorage("contextLength") private var contextLength: Int = DEFAULT_CONTEXT_LENGTH @AppStorage("playSoundEffects") private var playSoundEffects = true - @AppStorage("temperature") private var temperature: Double? @AppStorage("useGPU") private var useGPU: Bool = DEFAULT_USE_GPU - @AppStorage("serverHost") private var serverHost: String? - @AppStorage("serverPort") private var serverPort: String? - @AppStorage("serverTLS") private var serverTLS: Bool? - - @FetchRequest( - sortDescriptors: [NSSortDescriptor(keyPath: \Model.size, ascending: true)], - animation: .default) - private var models: FetchedResults + @AppStorage("temperature") private var temperature: Double = DEFAULT_TEMP private static let SEND = NSDataAsset(name: "ESM_Perfect_App_Button_2_Organic_Simple_Classic_Game_Click") private static let PING = NSDataAsset(name: "ESM_POWER_ON_SYNTH") @@ -41,15 +34,6 @@ struct ConversationView: View, Sendable { conversationManager.agent } - var selectedModel: Model? { - if selectedModelId != AISettingsView.remoteModelOption, - let selectedModelId = self.selectedModelId { - models.first(where: { $0.id?.uuidString == selectedModelId }) - } else { - models.first - } - } - @State var pendingMessage: Message? @State var messages: [Message] = [] @@ -75,8 +59,8 @@ struct ConversationView: View, Sendable { if m == pendingMessage { MessageView(pendingMessage!, overrideText: pendingMessageText, agentStatus: agent.status) .onAppear { - scrollToLastIfRecent(proxy) - } + scrollToLastIfRecent(proxy) + } .opacity(showResponse ? 1 : 0) .animation(.interpolatingSpring(stiffness: 170, damping: 20), value: showResponse) .id("\(m.id)\(m.updatedAt as Date?)") @@ -91,39 +75,44 @@ struct ConversationView: View, Sendable { } } } - .padding(.vertical, 12) - .onReceive( + .padding(.vertical, 12) + .onReceive( agent.$pendingMessage.throttle(for: .seconds(0.1), scheduler: RunLoop.main, latest: true) ) { text in pendingMessageText = text } - .onReceive( + .onReceive( agent.$pendingMessage.throttle(for: .seconds(0.2), scheduler: RunLoop.main, latest: true) ) { _ in DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) { autoScroll(proxy) } } + .onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("backendTypeIDDidChange"))) { _ in + initializeBackends() + } } - .textSelection(.enabled) - .safeAreaInset(edge: .bottom, spacing: 0) { + .textSelection(.enabled) + .safeAreaInset(edge: .bottom, spacing: 0) { MessageTextField { s in - submit(s) + Task { + await submit(s) + } } } - .frame(maxWidth: .infinity) - .onAppear { showConversation(conversation) } - .onChange(of: conversation) { nextConvo in showConversation(nextConvo) } - .onChange(of: selectedModelId) { showConversation(conversation, modelId: $0) } - .navigationTitle(conversation.titleWithDefault) - .alert(isPresented: $showErrorAlert, error: llamaError) { _ in + .frame(maxWidth: .infinity) + .onAppear { showConversation(conversation) } + .onChange(of: conversation) { nextConvo in showConversation(nextConvo) } + .onChange(of: selectedModelId) { showConversation(conversation, modelId: $0) } + .navigationTitle(conversation.titleWithDefault) + .alert(isPresented: $showErrorAlert, error: llamaError) { _ in Button("OK") { llamaError = nil } } message: { error in Text(error.recoverySuggestion ?? "") } - .background(Color.textBackground) + .background(Color.textBackground) } private func playSendSound() { @@ -139,46 +128,39 @@ struct ConversationView: View, Sendable { } private func showConversation(_ c: Conversation, modelId: String? = nil) { - guard - let selectedModelId = modelId ?? self.selectedModelId, - !selectedModelId.isEmpty - else { return } - messages = c.orderedMessages - - - // warmup the agent if it's cold or model has changed - Task { - if selectedModelId == AISettingsView.remoteModelOption { - await initializeServerRemote() - } else { - await initializeServerLocal(modelId: selectedModelId) - } + initializeBackends() + } + + private func initializeBackends() { + let backendType: BackendType = BackendType(rawValue: backendTypeID ?? "") ?? .local + if backendType == .local { + Task { try? await initializeBackendLocal() } } + + do { + let config = try fetchBackendConfig(context: viewContext) ?? BackendConfig(context: viewContext) + agent.createBackend(backendType, contextLength: contextLength, config: config) + + } catch { print("error fetching backend config", error) } } - private func initializeServerLocal(modelId: String) async { - guard let id = UUID(uuidString: modelId) + private func initializeBackendLocal() async throws { + guard let selectedModelId, !selectedModelId.isEmpty, + let id = UUID(uuidString: selectedModelId) else { return } let llamaPath = await agent.llama.modelPath let req = Model.fetchRequest() req.predicate = NSPredicate(format: "id == %@", id as CVarArg) - if let model = try? viewContext.fetch(req).first, + guard let model = try viewContext.fetch(req).first, let modelPath = model.url?.path(percentEncoded: false), - modelPath != llamaPath { - await agent.llama.stopServer() - agent.llama = LlamaServer(modelPath: modelPath, contextLength: contextLength) - } - } - - private func initializeServerRemote() async { - guard let tls = serverTLS, - let host = serverHost, - let port = serverPort + modelPath != llamaPath else { return } + await agent.llama.stopServer() - agent.llama = LlamaServer(contextLength: contextLength, tls: tls, host: host, port: port) + agent.llama = LlamaServer(modelPath: modelPath, contextLength: contextLength) + try await agent.llama.startServer() } private func scrollToLastIfRecent(_ proxy: ScrollViewProxy) { @@ -218,23 +200,12 @@ struct ConversationView: View, Sendable { showErrorAlert = true } - func submit(_ input: String) { + func submit(_ input: String) async { if (agent.status == .processing || agent.status == .coldProcessing) { - Task { - await agent.interrupt() - - Task.detached(priority: .userInitiated) { - try? await Task.sleep(for: .seconds(1)) - await submit(input) - } - } - return + await agent.interrupt() } playSendSound() - - guard let model = selectedModel else { return } - showUserMessage = false engageAutoScroll() @@ -252,8 +223,6 @@ struct ConversationView: View, Sendable { showUserMessage = true } - let messageTexts = messages.map { $0.text ?? "" } - // Pending message for bot's reply let m = Message(context: viewContext) m.fromId = agent.id @@ -280,49 +249,59 @@ struct ConversationView: View, Sendable { } } - Task { - var response: LlamaServer.CompleteResponse + let response: CompleteResponseSummary + do { + let config = try fetchBackendConfig(context: viewContext) + let messages = [RoleMessage(role: "system", content: systemPrompt)] + + messages.compactMap({ $0.text }).map({ RoleMessage(role: "user", content: $0) }) + let params = CompleteParams(messages: messages, + model: config?.model ?? Model.defaultModelUrl.deletingPathExtension().lastPathComponent, + numCTX: contextLength, + temperature: Float(temperature)) + response = try await agent.listenThinkRespond(speakerId: Message.USER_SPEAKER_ID, params: params) + } catch let error as LlamaServerError { + handleResponseError(error) + return + } catch { + print("agent listen threw unexpected error", error as Any) + return + } + + await MainActor.run { + m.text = response.text + m.predictedPerSecond = response.predictedPerSecond ?? -1 + m.responseStartSeconds = response.responseStartSeconds + m.nPredicted = Int64(response.nPredicted ?? -1) + m.modelName = response.modelName + m.updatedAt = Date() + + playReceiveSound() do { - response = try await agent.listenThinkRespond(speakerId: Message.USER_SPEAKER_ID, messages: messageTexts, template: model.template, temperature: temperature) - } catch let error as LlamaServerError { - handleResponseError(error) - return + try viewContext.save() } catch { - print("agent listen threw unexpected error", error as Any) - return + print("error creating message", error.localizedDescription) } - await MainActor.run { - m.text = response.text - m.predictedPerSecond = response.predictedPerSecond ?? -1 - m.responseStartSeconds = response.responseStartSeconds - m.nPredicted = Int64(response.nPredicted ?? -1) - m.modelName = response.modelName - m.updatedAt = Date() - - playReceiveSound() - do { - try viewContext.save() - } catch { - print("error creating message", error.localizedDescription) - } - - if pendingMessage?.text != nil, - !pendingMessage!.text!.isEmpty, - response.text.hasPrefix(agent.pendingMessage), - m == pendingMessage { - pendingMessage = nil - agent.pendingMessage = "" - } - - if conversation != agentConversation { - return - } + if pendingMessage?.text != nil, + !pendingMessage!.text!.isEmpty, + response.text.hasPrefix(agent.pendingMessage), + m == pendingMessage { + pendingMessage = nil + agent.pendingMessage = "" + } + if conversation == agentConversation { messages = agentConversation.orderedMessages } } } + + private func fetchBackendConfig(context: NSManagedObjectContext) throws -> BackendConfig? { + let backendType: BackendType = BackendType(rawValue: backendTypeID ?? "") ?? .local + let req = BackendConfig.fetchRequest() + req.predicate = NSPredicate(format: "backendType == %@", backendType.rawValue) + return try context.fetch(req).first + } } #Preview { diff --git a/mac/FreeChat/Views/Settings/AISettingsView.swift b/mac/FreeChat/Views/Settings/AISettingsView.swift index 15e9676..5d061f4 100644 --- a/mac/FreeChat/Views/Settings/AISettingsView.swift +++ b/mac/FreeChat/Views/Settings/AISettingsView.swift @@ -11,8 +11,6 @@ import SwiftUI struct AISettingsView: View { static let title = "Intelligence" private static let customizeModelsId = "customizeModels" - static let remoteModelOption = "remoteModelOption" - private let serverHealthTimer = Timer.publish(every: 3, on: .main, in: .common).autoconnect() @Environment(\.managedObjectContext) private var viewContext @@ -23,29 +21,30 @@ struct AISettingsView: View { animation: .default) private var models: FetchedResults - @AppStorage("selectedModelId") private var selectedModelId: String? + @AppStorage("backendTypeID") private var backendTypeID: String = BackendType.local.rawValue + @AppStorage("selectedModelId") private var selectedModelId: String? // Local only @AppStorage("systemPrompt") private var systemPrompt = DEFAULT_SYSTEM_PROMPT @AppStorage("contextLength") private var contextLength = DEFAULT_CONTEXT_LENGTH @AppStorage("temperature") private var temperature: Double = DEFAULT_TEMP @AppStorage("useGPU") private var useGPU = DEFAULT_USE_GPU - @AppStorage("serverTLS") private var serverTLS: Bool = false - @AppStorage("serverHost") private var serverHost: String? - @AppStorage("serverPort") private var serverPort: String? - @AppStorage("remoteModelTemplate") var remoteModelTemplate: String? + @AppStorage("openAIToken") private var openAIToken: String? @State var pickedModel: String? // Picker selection @State var customizeModels = false // Show add remove models - @State var editRemoteModel = false // Show remote model server @State var editSystemPrompt = false @State var editFormat = false @State var revealAdvanced = false - @State var inputServerTLS: Bool = false - @State var inputServerHost: String = "" - @State var inputServerPort: String = "" + @State var serverTLS: Bool = false + @State var serverHost: String = "" + @State var serverPort: String = "" + @State var serverAPIKey: String = "" @State var serverHealthScore: Double = -1 + @State var modelList: [String] = [] @StateObject var gpu = GPU.shared + private var isUsingLocalServer: Bool { backendTypeID == BackendType.local.rawValue } + let contextLengthFormatter: NumberFormatter = { let formatter = NumberFormatter() formatter.minimum = 1 @@ -58,14 +57,16 @@ struct AISettingsView: View { formatter.minimum = 0 return formatter }() - + var selectedModel: Model? { - if let selectedModelId = self.selectedModelId { + if let selectedModelId { models.first(where: { $0.id?.uuidString == selectedModelId }) } else { models.first } } + + var selectedModelName: String? { modelList.first } var systemPromptEditor: some View { VStack { @@ -90,55 +91,89 @@ struct AISettingsView: View { } } + var backendTypePicker: some View { + VStack(alignment: .leading) { + Picker("Backend", selection: $backendTypeID) { + ForEach(BackendType.allCases, id: \.self) { name in + Text(name.rawValue).tag(name.rawValue) + } + } + .onChange(of: backendTypeID) { + Task { + do { try await loadBackendConfig() } + catch let error { print("error fetching models:", error) } + } + NotificationCenter.default.post(name: NSNotification.Name("backendTypeIDDidChange"), object: $0) + } + Text(BackendType(rawValue: backendTypeID)?.howtoConfigure ?? "") + .font(.callout) + .foregroundColor(Color(NSColor.secondaryLabelColor)) + .lineLimit(5) + .fixedSize(horizontal: false, vertical: true) + .padding(.top, 0.5) + } + } + + @available(*, deprecated, message: "template is not supported") + var editPromptFormat: some View { + HStack { + Text("Prompt format \(selectedModel?.template.format.rawValue ?? "")") + .foregroundColor(Color(NSColor.secondaryLabelColor)) + .font(.caption) + Button("Edit") { + editFormat = true + } + .buttonStyle(.link).font(.caption) + .offset(x: -4) + } + .sheet(isPresented: $editFormat) { + if let model = selectedModelId { + EditFormat(modelName: model) + } else if !isUsingLocalServer { + EditFormat(modelName: "Remote") + } + } + } + var modelPicker: some View { VStack(alignment: .leading) { Picker("Model", selection: $pickedModel) { - ForEach(models) { i in - if let url = i.url { - Text(i.name ?? url.lastPathComponent) - .tag(i.id?.uuidString) - .help(url.path) - } + ForEach(modelList, id: \.self) { + Text($0) + .tag($0 as String?) + .help($0) } - - Divider().tag(nil as String?) - Text("Remote Model (Advanced)").tag(AISettingsView.remoteModelOption as String?) - Text("Add or Remove Models...").tag(AISettingsView.customizeModelsId as String?) - }.onReceive(Just(pickedModel)) { _ in - switch pickedModel { - case AISettingsView.customizeModelsId: - customizeModels = true - editRemoteModel = false - case AISettingsView.remoteModelOption: - customizeModels = false - editRemoteModel = true - selectedModelId = AISettingsView.remoteModelOption - case .some(let pickedModelValue): - customizeModels = false - editRemoteModel = false - selectedModelId = pickedModelValue - default: break + if isUsingLocalServer { + Divider().tag(nil as String?) + Text("Add or Remove Models...").tag(AISettingsView.customizeModelsId as String?) } } - .onChange(of: pickedModel) { newValue in - switch pickedModel { - case AISettingsView.customizeModelsId: + .disabled(backendTypeID == BackendType.llama.rawValue || modelList.isEmpty) + .onReceive(Just(pickedModel)) { _ in + if pickedModel == AISettingsView.customizeModelsId { customizeModels = true - editRemoteModel = false - case AISettingsView.remoteModelOption: - customizeModels = false - editRemoteModel = true - selectedModelId = AISettingsView.remoteModelOption - case .some(let pickedModelValue): - customizeModels = false - editRemoteModel = false - selectedModelId = pickedModelValue - default: break } + } + .onChange(of: pickedModel) { newValue in + guard newValue != AISettingsView.customizeModelsId else { return } + if let backendType: BackendType = BackendType(rawValue: backendTypeID) { + do { + if backendType == .local, + let model = models.filter({ $0.id?.uuidString == newValue }).first { + selectedModelId = model.id?.uuidString + pickedModel = model.name + } + let config = try findOrCreateBackendConfig(backendType, context: viewContext) + config.backendType = backendType.rawValue + config.model = pickedModel // newValue could be ID + try viewContext.save() + } + catch { print("error saving backend config:", error) } + } } - if !editRemoteModel { + if isUsingLocalServer { Text( "The default model is general purpose, small, and works on most computers. Larger models are slower but wiser. Some models specialize in certain tasks like coding Python. FreeChat is compatible with most models in GGUF format. [Find new models](https://huggingface.co/models?search=GGUF)" ) @@ -148,56 +183,21 @@ struct AISettingsView: View { .fixedSize(horizontal: false, vertical: true) .padding(.top, 0.5) } - - HStack { - if let model = selectedModel { - Text("Prompt format: \(model.template.format.rawValue)") - .foregroundColor(Color(NSColor.secondaryLabelColor)) - .font(.caption) - } else if editRemoteModel { - Text("Prompt format: \(remoteModelTemplate ?? TemplateFormat.vicuna.rawValue)") - .foregroundColor(Color(NSColor.secondaryLabelColor)) - .font(.caption) - } - Button("Edit") { - editFormat = true - } - .buttonStyle(.link).font(.caption) - .offset(x: -4) - } - .sheet( - isPresented: $editFormat, - content: { - if let model = selectedModel { - EditFormat(model: model) - } else if editRemoteModel { - EditFormat(modelName: "Remote") - } - }) } } - var hasRemoteServerInputChanged: Bool { - inputServerHost != serverHost || inputServerPort != serverPort || inputServerTLS != serverTLS - } var hasRemoteConnectionError: Bool { serverHealthScore < 0.25 && serverHealthScore >= 0 } var indicatorColor: Color { switch serverHealthScore { - case 0..<0.25: - Color(red: 1, green: 0, blue: 0) - case 0.25..<0.5: - Color(red: 1, green: 0.5, blue: 0) - case 0.5..<0.75: - Color(red: 0.45, green: 0.55, blue: 0) - case 0.75..<0.95: - Color(red: 0.1, green: 0.9, blue: 0) - case 0.95...1: - Color(red: 0, green: 1, blue: 0) - default: - Color(red: 0.5, green: 0.5, blue: 0.5) + case 0..<0.25: Color(red: 1, green: 0, blue: 0) + case 0.25..<0.5: Color(red: 1, green: 0.5, blue: 0) + case 0.5..<0.75: Color(red: 0.45, green: 0.55, blue: 0) + case 0.75..<0.95: Color(red: 0.1, green: 0.9, blue: 0) + case 0.95...1: Color(red: 0, green: 1, blue: 0) + default: Color(red: 0.5, green: 0.5, blue: 0.5) } } @@ -209,12 +209,9 @@ struct AISettingsView: View { .foregroundColor(indicatorColor) Group { switch serverHealthScore { - case 0.25...1: - Text("Connected") - case 0..<0.25: - Text("Connection Error. Retrying...") - default: - Text("Not Connected") + case 0.25...1: Text("Connected") + case 0..<0.25: Text("Connection Error. Retrying...") + default: Text("Not Connected") } } .font(.callout) @@ -230,34 +227,30 @@ struct AISettingsView: View { } } - var sectionRemoteModel: some View { + var sectionRemoteBackend: some View { Group { - Text( - "If you have access to a powerful server, you may want to run your model there. Enter the host and port to connect to a remote llama.cpp server. Instructions for running the server can be found [here](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md)" - ) - .font(.callout) - .foregroundColor(Color(NSColor.secondaryLabelColor)) - .lineLimit(5) - .fixedSize(horizontal: false, vertical: true) - .padding(.top, 0.5) HStack { - TextField("Server host", text: $inputServerHost, prompt: Text("yourserver.net")) + TextField("Server host", text: $serverHost, prompt: Text("yourserver.net")) .textFieldStyle(.plain) .font(.callout) - TextField("Server port", text: $inputServerPort, prompt: Text("3000")) + TextField("Server port", text: $serverPort, prompt: Text("8690")) .textFieldStyle(.plain) .font(.callout) Spacer() } - Toggle(isOn: $inputServerTLS) { + Toggle(isOn: $serverTLS) { Text("Secure connection (HTTPS)") .font(.callout) } + HStack { + SecureField("API Key", text: $serverAPIKey) + .textFieldStyle(.plain) + .font(.callout) + } HStack { serverHealthIndication Spacer() - Button("Apply", action: saveFormRemoteServer) - .disabled(!hasRemoteServerInputChanged && !hasRemoteConnectionError) + Button("Apply", action: saveFormRemoteBackend) } } } @@ -266,9 +259,10 @@ struct AISettingsView: View { Form { Section { systemPromptEditor + backendTypePicker modelPicker - if editRemoteModel { - sectionRemoteModel + if !isUsingLocalServer { + sectionRemoteBackend } } Section { @@ -277,7 +271,7 @@ struct AISettingsView: View { content: { VStack(alignment: .leading) { HStack { - Text("Configure llama.cpp based on the model you're using.") + Text("Configure your backend based on the model you're using.") .foregroundColor(Color(NSColor.secondaryLabelColor)) Button("Restore defaults") { contextLength = DEFAULT_CONTEXT_LENGTH @@ -288,20 +282,16 @@ struct AISettingsView: View { .padding(.top, 2.5) .padding(.bottom, 4) - if !editRemoteModel { - Divider() - - HStack { - Text("Context Length") - TextField("", value: $contextLength, formatter: contextLengthFormatter) - .padding(.vertical, -8) - .padding(.trailing, -10) - } - .padding(.top, 0.5) + Divider() + HStack { + Text("Context Length") + TextField("", value: $contextLength, formatter: contextLengthFormatter) + .padding(.vertical, -8) + .padding(.trailing, -10) } + .padding(.top, 0.5) Divider() - HStack { Text("Temperature") Slider(value: $temperature, in: 0...2, step: 0.1).offset(y: 1) @@ -311,7 +301,7 @@ struct AISettingsView: View { .frame(width: 24, alignment: .trailing) }.padding(.top, 1) - if gpu.available && !editRemoteModel { + if gpu.available && isUsingLocalServer { Divider() Toggle("Use GPU Acceleration", isOn: $useGPU).padding(.top, 1) @@ -333,84 +323,140 @@ struct AISettingsView: View { } } .formStyle(.grouped) - .sheet(isPresented: $customizeModels, onDismiss: { pickedModel = selectedModelId }) { + .sheet(isPresented: $customizeModels, onDismiss: { setPickedModelFromID(modelID: selectedModelId) }) { EditModels(selectedModelId: $selectedModelId) } .sheet(isPresented: $editSystemPrompt) { EditSystemPrompt() } - .onSubmit(saveFormRemoteServer) + .onSubmit(saveFormRemoteBackend) .navigationTitle(AISettingsView.title) .onAppear { - if selectedModelId != AISettingsView.remoteModelOption { - let selectedModelExists = - models - .compactMap({ $0.id?.uuidString }) - .contains(selectedModelId) - if !selectedModelExists { - selectedModelId = models.first?.id?.uuidString - } + Task { + do { try await loadBackendConfig() } + catch let error { print("error fetching models:", error) } } - pickedModel = selectedModelId - - inputServerTLS = serverTLS - inputServerHost = serverHost ?? "" - inputServerPort = serverPort ?? "" - updateRemoteServerURL() } - .onChange(of: selectedModelId) { newModelId in - pickedModel = newModelId - guard - let model = models.first(where: { $0.id?.uuidString == newModelId }) ?? models.first - else { return } - - conversationManager.rebootAgent( - systemPrompt: self.systemPrompt, model: model, viewContext: viewContext) + .onChange(of: selectedModelId) { _ in + Task { + try? await fetchModels(backendType: .local) + setPickedModelFromID(modelID: selectedModelId) + } + if isUsingLocalServer { rebootAgentWithSelectedModel() } } - .onChange(of: systemPrompt) { nextPrompt in - guard let model: Model = selectedModel else { return } - conversationManager.rebootAgent( - systemPrompt: nextPrompt, model: model, viewContext: viewContext) + .onChange(of: systemPrompt) { _ in + if isUsingLocalServer { rebootAgentWithSelectedModel() } } - .onChange(of: useGPU) { nextUseGPU in - guard let model: Model = selectedModel else { return } - conversationManager.rebootAgent( - systemPrompt: self.systemPrompt, model: model, viewContext: viewContext) + .onChange(of: useGPU) { _ in + if isUsingLocalServer { rebootAgentWithSelectedModel() } } .onReceive( - NotificationCenter.default.publisher(for: NSNotification.Name("selectedModelDidChange")) + NotificationCenter.default.publisher(for: NSNotification.Name("selectedLocalModelDidChange")) ) { output in - if let updatedId: String = output.object as? String { - selectedModelId = updatedId - } + setPickedModelFromID(modelID: output.object as? String) } .frame( minWidth: 300, maxWidth: 600, minHeight: 184, idealHeight: 195, maxHeight: 400, alignment: .center) } - private func saveFormRemoteServer() { - serverTLS = inputServerTLS - serverHost = inputServerHost - serverPort = inputServerPort + private func saveFormRemoteBackend() { + guard let backendType: BackendType = BackendType(rawValue: backendTypeID), + let config = try? findOrCreateBackendConfig(backendType, context: viewContext), + let url = URL(string: "\(serverTLS && config.baseURL != nil ? "https" : "http")://\(serverHost):\(serverPort)") // Default to TLS disabled + else { return } + serverHealthScore = -1 - updateRemoteServerURL() - - selectedModelId = AISettingsView.remoteModelOption - } + config.apiKey = serverAPIKey + config.baseURL = url + if modelList.contains(pickedModel ?? "") { config.model = pickedModel } + do { try viewContext.save() } + catch { print("error saving backend", error) } - private func updateRemoteServerURL() { - let scheme = inputServerTLS ? "https" : "http" - guard let url = URL(string: "\(scheme)://\(inputServerHost):\(inputServerPort)/health") - else { return } + serverTLS = config.baseURL?.scheme == "https" // Match the UI value Task { + if modelList.isEmpty { try? await fetchModels(backendType: backendType) } await ServerHealth.shared.updateURL(url) await ServerHealth.shared.check() } } + + // MARK: - Fetch models + + private func fetchModels(backendType: BackendType) async throws { + let baseURL = URL(string: "\(serverTLS ? "https" : "http")://\(serverHost):\(serverPort)") ?? backendType.defaultURL + modelList.removeAll() + + switch backendType { + case .local: + let baseURL = BackendType.local.defaultURL + let backend = LocalBackend(baseURL: baseURL, apiKey: nil) + modelList = try await backend.listModels() + case .llama: + modelList = ["Unavailable"] + case .openai: + let backend = OpenAIBackend(baseURL: baseURL, apiKey: nil) + modelList = backend.listModels() + case .ollama: + let backend = OllamaBackend(baseURL: baseURL, apiKey: nil) + modelList = try await backend.listModels() + } + + if !modelList.contains(pickedModel ?? "") { pickedModel = modelList.first } + } + + private func rebootAgentWithSelectedModel() { + guard let selectedModelId else { return } + let req = Model.fetchRequest() + req.predicate = NSPredicate(format: "id == %@", selectedModelId) + do { + if let model = try viewContext.fetch(req).first { + conversationManager.rebootAgent(systemPrompt: self.systemPrompt, model: model, viewContext: viewContext) + } + } catch { print("error fetching model id:", selectedModelId, error) } + } + + private func setPickedModelFromID(modelID: String?) { + guard let model = models.filter({ $0.id?.uuidString == modelID }).first + else { return } + selectedModelId = model.id?.uuidString + pickedModel = model.name + backendTypeID = BackendType.local.rawValue + } + + // MARK: - Backend config + + private func loadBackendConfig() async throws { + let backendType: BackendType = BackendType(rawValue: backendTypeID) ?? .local + let config = try findOrCreateBackendConfig(backendType, context: viewContext) + if backendType == .local { + let model = models.first(where: { $0.id?.uuidString == selectedModelId }) ?? models.first + config.model = model?.name + selectedModelId = model?.id?.uuidString + } + + if config.baseURL == nil { config.baseURL = backendType.defaultURL } + serverTLS = config.baseURL?.scheme == "https" ? true : false + serverHost = config.baseURL?.host() ?? "" + serverPort = "\(config.baseURL?.port ?? 8690)" + serverAPIKey = config.apiKey ?? "" + + try await fetchModels(backendType: backendType) + config.model = config.model ?? modelList.first + pickedModel = config.model + try viewContext.save() + + await ServerHealth.shared.updateURL(config.baseURL) + } + + private func findOrCreateBackendConfig(_ backendType: BackendType, context: NSManagedObjectContext) throws -> BackendConfig { + let req = BackendConfig.fetchRequest() + req.predicate = NSPredicate(format: "backendType == %@", backendType.rawValue) + return try context.fetch(req).first ?? BackendConfig(context: context) + } } #Preview{ - AISettingsView(inputServerTLS: true) + AISettingsView() .environment(\.managedObjectContext, PersistenceController.preview.container.viewContext) }