Skip to content
Open

wip #138

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions TypeaheadAI.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
2B27450A2AB01CF400F37D3E /* SpecialSaveActor.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2B2745092AB01CF400F37D3E /* SpecialSaveActor.swift */; };
2B27450E2AB0380C00F37D3E /* AppContextManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2B27450D2AB0380C00F37D3E /* AppContextManager.swift */; };
2B2745102AB03A3D00F37D3E /* CanSimulateCopy.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2B27450F2AB03A3D00F37D3E /* CanSimulateCopy.swift */; };
2B285D852ACA22FB000C5BDE /* LaunchAtLogin in Frameworks */ = {isa = PBXBuildFile; productRef = 2B285D842ACA22FB000C5BDE /* LaunchAtLogin */; };
2B2EF14E2AC17D4000EF2BD4 /* CustomTextField.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2B2EF14D2AC17D4000EF2BD4 /* CustomTextField.swift */; };
2B2EF1502AC40C8F00EF2BD4 /* ChatBubble.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2B2EF14F2AC40C8F00EF2BD4 /* ChatBubble.swift */; };
2B2EF1522AC40CB500EF2BD4 /* MessagePendingView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2B2EF1512AC40CB500EF2BD4 /* MessagePendingView.swift */; };
Expand Down Expand Up @@ -55,7 +56,6 @@
2BDA45C32ABEE840006128BC /* MessageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2BDA45C22ABEE840006128BC /* MessageView.swift */; };
2BE0EC222AA0956C00E47C52 /* ModalView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2BE0EC212AA0956C00E47C52 /* ModalView.swift */; };
2BE0EC272AA17F9100E47C52 /* MouseClickMonitor.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2BE0EC262AA17F9100E47C52 /* MouseClickMonitor.swift */; };
2BF558BE2AB8353B002F2008 /* LaunchAtLogin in Frameworks */ = {isa = PBXBuildFile; productRef = 2BF558BD2AB8353B002F2008 /* LaunchAtLogin */; };
2BF929792AB04D2600FC105B /* MemoManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2BF929782AB04D2600FC105B /* MemoManager.swift */; };
2BF9297C2AB13EEA00FC105B /* Markdown in Frameworks */ = {isa = PBXBuildFile; productRef = 2BF9297B2AB13EEA00FC105B /* Markdown */; };
2BF929802AB13F3600FC105B /* CodeBlockView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2BF9297F2AB13F3600FC105B /* CodeBlockView.swift */; };
Expand Down Expand Up @@ -144,8 +144,8 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
2B285D852ACA22FB000C5BDE /* LaunchAtLogin in Frameworks */,
2B4BDB892ACBED2100E55D78 /* SettingsAccess in Frameworks */,
2BF558BE2AB8353B002F2008 /* LaunchAtLogin in Frameworks */,
2BA3C2352AADAC5700537F95 /* llama in Frameworks */,
2B473E8C2AA860380042913D /* MenuBarExtraAccess in Frameworks */,
2BF929852AB13FEC00FC105B /* Highlighter in Frameworks */,
Expand Down Expand Up @@ -366,7 +366,7 @@
2BA3C2342AADAC5700537F95 /* llama */,
2BF9297B2AB13EEA00FC105B /* Markdown */,
2BF929842AB13FEC00FC105B /* Highlighter */,
2BF558BD2AB8353B002F2008 /* LaunchAtLogin */,
2B285D842ACA22FB000C5BDE /* LaunchAtLogin */,
2B4BDB882ACBED2100E55D78 /* SettingsAccess */,
);
productName = TypeaheadAI;
Expand Down Expand Up @@ -448,7 +448,7 @@
2BA3C2332AADAC5700537F95 /* XCRemoteSwiftPackageReference "llama" */,
2BF9297A2AB13EEA00FC105B /* XCRemoteSwiftPackageReference "swift-markdown" */,
2BF929832AB13FEC00FC105B /* XCRemoteSwiftPackageReference "HighlighterSwift" */,
2BF558BC2AB8353B002F2008 /* XCRemoteSwiftPackageReference "LaunchAtLogin-Modern" */,
2B285D832ACA22FB000C5BDE /* XCRemoteSwiftPackageReference "LaunchAtLogin-Modern" */,
2B4BDB872ACBED2100E55D78 /* XCRemoteSwiftPackageReference "SettingsAccess" */,
);
productRefGroup = 2BA7F0762A9ABBA8003D38BA /* Products */;
Expand Down Expand Up @@ -888,6 +888,14 @@
/* End XCConfigurationList section */

/* Begin XCRemoteSwiftPackageReference section */
2B285D832ACA22FB000C5BDE /* XCRemoteSwiftPackageReference "LaunchAtLogin-Modern" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/sindresorhus/LaunchAtLogin-Modern";
requirement = {
branch = main;
kind = branch;
};
};
2B473E8A2AA860380042913D /* XCRemoteSwiftPackageReference "MenuBarExtraAccess" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/orchetect/MenuBarExtraAccess";
Expand Down Expand Up @@ -920,14 +928,6 @@
minimumVersion = 1.0.0;
};
};
2BF558BC2AB8353B002F2008 /* XCRemoteSwiftPackageReference "LaunchAtLogin-Modern" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/sindresorhus/LaunchAtLogin-Modern";
requirement = {
branch = main;
kind = branch;
};
};
2BF9297A2AB13EEA00FC105B /* XCRemoteSwiftPackageReference "swift-markdown" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/apple/swift-markdown";
Expand All @@ -947,6 +947,11 @@
/* End XCRemoteSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
2B285D842ACA22FB000C5BDE /* LaunchAtLogin */ = {
isa = XCSwiftPackageProductDependency;
package = 2B285D832ACA22FB000C5BDE /* XCRemoteSwiftPackageReference "LaunchAtLogin-Modern" */;
productName = LaunchAtLogin;
};
2B473E8B2AA860380042913D /* MenuBarExtraAccess */ = {
isa = XCSwiftPackageProductDependency;
package = 2B473E8A2AA860380042913D /* XCRemoteSwiftPackageReference "MenuBarExtraAccess" */;
Expand All @@ -967,11 +972,6 @@
package = 2BA7F0AC2A9ABC47003D38BA /* XCRemoteSwiftPackageReference "KeyboardShortcuts" */;
productName = KeyboardShortcuts;
};
2BF558BD2AB8353B002F2008 /* LaunchAtLogin */ = {
isa = XCSwiftPackageProductDependency;
package = 2BF558BC2AB8353B002F2008 /* XCRemoteSwiftPackageReference "LaunchAtLogin-Modern" */;
productName = LaunchAtLogin;
};
2BF9297B2AB13EEA00FC105B /* Markdown */ = {
isa = XCSwiftPackageProductDependency;
package = 2BF9297A2AB13EEA00FC105B /* XCRemoteSwiftPackageReference "swift-markdown" */;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<key>TypeaheadAI.xcscheme_^#shared#^_</key>
<dict>
<key>orderHint</key>
<integer>1</integer>
<integer>0</integer>
</dict>
</dict>
</dict>
Expand Down
68 changes: 40 additions & 28 deletions TypeaheadAI/ClientManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,25 @@ class ClientManager {
version: "onboarding_v3"
)

if let result: Result<String, Error> = await self?.performOnboardingTask(payload: payload, timeout: timeout, streamHandler: streamHandler) {
completion(result)
} else {
completion(.failure(ClientManagerError.appError("Something went wrong...")))
do {
let stream = try self?.performOnboardingTask(
payload: payload,
timeout: timeout,
completion: completion
)

guard let stream = stream else {
self?.logger.debug("Failed to get stream")
streamHandler(.failure(ClientManagerError.networkError("Failed to connect")))
return
}

for try await text in stream {
self?.logger.debug("stream: \(text)")
streamHandler(.success(text))
}
} catch {
streamHandler(.failure(error))
}
}
}
Expand Down Expand Up @@ -592,37 +607,34 @@ class ClientManager {
private func performOnboardingTask(
payload: OnboardingRequestPayload,
timeout: TimeInterval,
streamHandler: @escaping (Result<String, Error>) -> Void
) async -> Result<String, Error> {
completion: @escaping (Result<String, Error>) -> Void
) throws -> AsyncThrowingStream<String, Error> {
guard let httpBody = try? JSONEncoder().encode(payload) else {
let error: Result<String, Error> = .failure(ClientManagerError.badRequest("Encoding error"))
streamHandler(error)
return error
throw ClientManagerError.badRequest("Encoding error")
}

var request = URLRequest(url: self.apiOnboarding, timeoutInterval: timeout)
request.httpMethod = "POST"
request.httpBody = httpBody
request.addValue("application/json", forHTTPHeaderField: "Content-Type")

var output = ""
do {
let (stream, _) = try await URLSession.shared.bytes(for: request)
return AsyncThrowingStream { continuation in
Task {
var urlRequest = URLRequest(url: self.apiOnboarding, timeoutInterval: timeout)
urlRequest.httpMethod = "POST"
urlRequest.httpBody = httpBody
urlRequest.addValue("application/json", forHTTPHeaderField: "Content-Type")

for try await line in stream.lines {
let decodedResponse = try JSONDecoder().decode(ChunkPayload.self, from: line.data(using: .utf8)!)
if let text = decodedResponse.text {
output += text
streamHandler(.success(text))
let (result, _) = try await URLSession.shared.bytes(for: urlRequest)
var output = ""
for try await line in result.lines {
if let data = line.data(using: .utf8),
let response = try? JSONDecoder().decode(ChunkPayload.self, from: data),
let text = response.text {
output += text
continuation.yield(text)
}
}

completion(.success(output))
continuation.finish()
}
} catch {
let err: Result<String, Error> = .failure(error)
streamHandler(err)
return err
}

return .success(output)
}

private func performStreamOfflineTask(
Expand Down
22 changes: 22 additions & 0 deletions TypeaheadAI/Llama/LlamaWrapper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ class LlamaWrapper {
return model != nil
}

func stream(
_ prompt: String
) -> AsyncThrowingStream<String, Error> {
ctx = llama_new_context_with_model(model, params)
return AsyncThrowingStream { continuation in
do {
try simple_predict(ctx, prompt, 1) { string in
continuation.yield(string)
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}

guard let cstr = simple_predict(ctx, prompt, 1, globalHandler) else {
throw LlamaWrapperError.serverError("Failed to run simple_predict")
}


}

func predict(
_ prompt: String,
handler: @escaping (Result<String, Error>) -> Void
Expand Down
28 changes: 7 additions & 21 deletions TypeaheadAI/ModalManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class ModalManager: ObservableObject {

/// Set an error message.
@MainActor
func setError(_ responseError: String) {
func setError(_ responseError: String) async {
isPending = false

if let idx = messages.indices.last, !messages[idx].isCurrentUser {
Expand Down Expand Up @@ -277,15 +277,11 @@ class ModalManager: ObservableObject {

switch result {
case .success(let chunk):
Task {
await self.appendText(chunk)
}
self.logger.info("Received chunk: \(chunk)")
await self.appendText(chunk)
case .failure(let error):
Task {
self.setError(error.localizedDescription)
}
self.logger.error("An error occurred: \(error)")
await self.setError(error.localizedDescription)
}
}, completion: defaultCompletionHandler)
}
Expand Down Expand Up @@ -482,28 +478,18 @@ class ModalManager: ObservableObject {
func defaultHandler(result: Result<String, Error>) {
switch result {
case .success(let chunk):
Task {
await self.appendText(chunk)
}
self.logger.info("Received chunk: \(chunk)")
await self.appendText(chunk)
case .failure(let error as ClientManagerError):
self.logger.error("Error: \(error.localizedDescription)")
switch error {
case .badRequest(let message):
DispatchQueue.main.async {
self.setError(message)
}
await self.setError(message)
default:
DispatchQueue.main.async {
self.setError("Something went wrong. Please try again.")
}
self.logger.error("Something went wrong.")
await self.setError("Something went wrong. Please try again.")
}
case .failure(let error):
self.logger.error("Error: \(error.localizedDescription)")
DispatchQueue.main.async {
self.setError(error.localizedDescription)
}
await self.setError(error.localizedDescription)
}
}
}
Expand Down
Loading