diff --git a/Cargo.lock b/Cargo.lock index d5dd75d..37d6f07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,13 +2,109 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "bytes" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "libc" +version = "0.2.179" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5a2d376baa530d1238d133232d15e239abad80d05838b4b59354e5268af431f" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "parrot" version = "0.0.2" dependencies = [ + "async-trait", "thiserror", + "tokio", ] +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + [[package]] name = "proc-macro2" version = "1.0.104" @@ -27,6 +123,47 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17129e116933cf371d018bb80ae557e889637989d8638274fb25622827b03881" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + [[package]] name = "syn" version = "2.0.111" @@ -58,8 +195,131 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio" +version = "1.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" diff --git a/Cargo.toml b/Cargo.toml index 0ada7a4..1e52be9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,8 @@ version = "0.0.2" edition = "2024" [dependencies] +async-trait = "0.1.89" thiserror = "2.0.17" + +[dev-dependencies] +tokio = { version = "1.49.0", features = ["full"] } diff --git a/examples/basic.rs b/examples/basic.rs new file mode 100644 index 0000000..6388388 --- /dev/null +++ b/examples/basic.rs @@ -0,0 +1,15 @@ +use parrot::llm::get_available_models; + +#[tokio::main] +async fn main() { + let available_models = get_available_models().expect("failed to get models"); + + for model in available_models.into_iter() { + let out = model + .prompt("reply with a check mark if this request is successful") + .await + .expect("failed to prompt"); + + println!("{} - {}", model.get_name(), out.trim()); + } +} diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs index 8d992c1..d2a7de5 100644 --- a/src/llm/anthropic.rs +++ b/src/llm/anthropic.rs @@ -2,6 +2,8 @@ use crate::error::LLMError; use super::{Model, ModelFactory, constants::names}; +use async_trait::async_trait; + use std::env::var; const ANTHROPIC_API_KEY: &str = "ANTHROPIC_API_KEY"; @@ -21,12 +23,13 @@ impl ModelFactory for Anthropic { } } +#[async_trait] impl Model for Anthropic { fn get_name(&self) -> String { names::ANTHROPIC.into() } - fn prompt(&self, _: &str) -> Result { + async fn prompt(&self, _: &str) -> Result { unimplemented!("anthropic api") } } diff --git a/src/llm/claude.rs b/src/llm/claude.rs index b3300da..2e1b063 100644 --- a/src/llm/claude.rs +++ b/src/llm/claude.rs @@ -1,5 +1,7 @@ use std::process::Command; +use async_trait::async_trait; + use crate::{ error::LLMError, llm::{Model, ModelFactory, constants::names}, @@ -20,12 +22,13 @@ impl ModelFactory for Claude { } } +#[async_trait] impl Model for Claude { fn get_name(&self) -> String { names::CLAUDE.into() } - fn prompt(&self, input: &str) -> Result { + async fn prompt(&self, input: &str) -> Result { let out = Command::new(CLAUDE_CLI_NAME) .args([input, "-p"]) .output() diff --git a/src/llm/cursor.rs b/src/llm/cursor.rs index 83011eb..c5bcc54 100644 --- a/src/llm/cursor.rs +++ b/src/llm/cursor.rs @@ -1,3 +1,5 @@ +use async_trait::async_trait; + use crate::{ error::LLMError, llm::{Model, ModelFactory, constants::names}, @@ -20,12 +22,13 @@ impl ModelFactory for CursorCLI { } } +#[async_trait] impl Model for CursorCLI { fn get_name(&self) -> String { names::CURSOR.into() } - fn prompt(&self, input: &str) -> Result { + async fn prompt(&self, input: &str) -> Result { let out = Command::new(CURSOR_CLI_NAME) .args([input, "-p"]) .output() diff --git a/src/llm/mod.rs b/src/llm/mod.rs index bb855d7..8bb4022 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -10,6 +10,8 @@ mod openai; /// model names in external projects. pub mod constants; +use async_trait::async_trait; + use crate::{ error::LLMError, llm::{anthropic::Anthropic, claude::Claude, cursor::CursorCLI, openai::OpenAI}, @@ -21,7 +23,7 @@ use crate::{ /// Note: It's required to separate this from the actual model /// trait, that defines the interface for interaction with the /// LLMs. This is because we store the actual `Model` in a boxed -/// vector. +/// vector in the `get_available_models` method. /// /// To allow this, the instances of `Model` have to be `dyn`-compatible, /// which requires the trait not to be `Sized` as it is described here: @@ -32,9 +34,10 @@ pub trait ModelFactory: Model + Sized { /// Defines the required functionality /// to interact with a language model. +#[async_trait] pub trait Model: Send + Sync { fn get_name(&self) -> String; - fn prompt(&self, input: &str) -> Result; + async fn prompt(&self, input: &str) -> Result; } /// Returns the available models in the current diff --git a/src/llm/openai.rs b/src/llm/openai.rs index 6b98896..b6980f3 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -1,5 +1,7 @@ use std::env::var; +use async_trait::async_trait; + use crate::{ error::LLMError, llm::{Model, ModelFactory, constants::names}, @@ -22,12 +24,13 @@ impl ModelFactory for OpenAI { } } +#[async_trait] impl Model for OpenAI { fn get_name(&self) -> String { names::OPENAI.into() } - fn prompt(&self, _: &str) -> Result { + async fn prompt(&self, _: &str) -> Result { unimplemented!("open ai api") } } diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 2230272..0000000 --- a/src/main.rs +++ /dev/null @@ -1,13 +0,0 @@ -use parrot::llm::get_available_models; - -fn main() { - let available_models = get_available_models().expect("failed to get models"); - - available_models.iter().for_each(|m| { - let out = m - .prompt("say hello to my friends") - .expect("failed to prompt"); - - println!("{} - {}", m.get_name(), out); - }) -}