From bc153ea2ddac9a1457bd94e8797f0049e3c14e52 Mon Sep 17 00:00:00 2001
From: Towoadeyemi1 <38192891+Towoadeyemi1@users.noreply.github.com>
Date: Thu, 20 Feb 2025 18:05:51 -0700
Subject: [PATCH 1/2] Add Ollama AI integration with local model support

- Implemented Ollama proxy routes and client-side model detection
- Added support for generating text using local Ollama models
- Created new Ollama-specific routes and composable functions
- Enhanced AI interaction handling to support Ollama provider
- Added Ollama status detection and model availability checks
---
 config/app.js                   |  25 +++
 config/handleAiInteractions.js  | 290 ++++++++++++++++++++------------
 public/App.js                   |  13 ++
 public/components/AgentCard.js  |  64 ++++++-
 public/composables/useModels.js | 117 +++++++++++--
 server/routes/ollama.js         |  60 +++++++
 6 files changed, 439 insertions(+), 130 deletions(-)
 create mode 100644 server/routes/ollama.js

diff --git a/config/app.js b/config/app.js
index 351891b..4c36029 100644
--- a/config/app.js
+++ b/config/app.js
@@ -41,8 +41,33 @@ app.use((req, res, next) => {
     next();
 });
 
+// Add Ollama proxy route
+app.post('/api/ollama/:model', async (req, res) => {
+    try {
+        const response = await fetch(`http://localhost:11434/api/generate`, {
+            method: 'POST',
+            headers: {
+                'Content-Type': 'application/json',
+            },
+            body: JSON.stringify({
+                model: req.params.model,
+                prompt: req.body.prompt,
+                stream: false
+            }),
+        });
+
+        const data = await response.json();
+        res.json({ text: data.response });
+    } catch (error) {
+        console.error('Ollama API error:', error);
+        res.status(500).json({ error: 'Failed to communicate with Ollama' });
+    }
+});
 
 
+const ollamaRoutes = require('../server/routes/ollama');
+// Add Ollama routes
+app.use('/api/ollama', ollamaRoutes);
 
 //Export the app for use on the index.js page
 module.exports = { app };
diff --git a/config/handleAiInteractions.js b/config/handleAiInteractions.js
index e3ac024..f3876d4 100644
--- a/config/handleAiInteractions.js
+++ b/config/handleAiInteractions.js
@@ -8,36 +8,76 @@ const { Mistral } = require("@mistralai/mistralai");
 const { GoogleGenerativeAI, HarmCategory, HarmBlockThreshold } = require("@google/generative-ai");
 
 // Helper function to create provider-specific clients
-const createClient = (provider, credentials) => {
-    const envKey = process.env[`${provider.toUpperCase()}_API_KEY`];
-    const apiKey = credentials?.apiKey || envKey;
-
-    if (!apiKey) {
-        throw new Error(`No API key available for ${provider}`);
-    }
+const createClient = async (provider, modelConfig) => {
+    // Normalize provider to lowercase
+    const normalizedProvider = provider.toLowerCase();
+
+    // Define which providers need credentials
+    const requiresCredentials = {
+        'openai': true,
+        'anthropic': true,
+        'azureai': true,
+        'mistral': true,
+        'groq': true,
+        'gemini': true,
+        'ollama': false
+    };
+
+    const envKey = process.env[`${provider.toUpperCase()}_API_KEY`];
+    const credentials = modelConfig.apiKey || envKey;
+
+    // Check credentials only for providers that require them
+    if (requiresCredentials[provider.toLowerCase()] && !credentials) {
+        throw new Error(`No API key available for ${provider}`);
+    }
 
     // console.log('LLM Request for ', provider)
 
     switch (provider.toLowerCase()) {
         case 'openai':
-            return new OpenAI({ apiKey });
+            return new OpenAI({ apiKey: credentials });
         case 'anthropic':
-            return new Anthropic({ apiKey });
+            return new Anthropic({ apiKey: credentials });
         case 'azureai':
            const endpoint = credentials?.apiEndpoint || process.env.AZUREAI_ENDPOINT;
            if (!endpoint) {
                throw new Error('AzureAI requires both an API key and endpoint. No endpoint was provided.');
            }
-           if (!apiKey) {
+           if (!credentials) {
                throw new Error('AzureAI requires both an API key and endpoint. No API key was provided.');
            }
-           return new OpenAIClient(endpoint, new AzureKeyCredential(apiKey));
+           return new OpenAIClient(endpoint, new AzureKeyCredential(credentials));
         case 'mistral':
-            return new Mistral({ apiKey });
+            return new Mistral({ apiKey: credentials });
         case 'groq':
-            return new Groq({ apiKey });
+            return new Groq({ apiKey: credentials });
         case 'gemini':
-            return new GoogleGenerativeAI(apiKey);
+            return new GoogleGenerativeAI(credentials);
+        case 'ollama':
+            console.log('Creating Ollama client with model:', modelConfig.model);
+            return {
+                provider: 'ollama',
+                model: modelConfig.model,
+                completions: {
+                    create: async (config) => {
+                        const response = await fetch('http://localhost:11434/api/generate', {
+                            method: 'POST',
+                            headers: { 'Content-Type': 'application/json' },
+                            body: JSON.stringify({
+                                model: modelConfig.model,
+                                prompt: config.messages.map(m => `${m.role}: ${m.content}`).join('\n'),
+                                stream: true
+                            })
+                        });
+
+                        if (!response.ok) {
+                            throw new Error('Ollama generation failed');
+                        }
+
+                        return response.body;
+                    }
+                }
+            };
         default:
             throw new Error(`Unsupported provider: ${provider}`);
     }
@@ -56,7 +96,7 @@ const validateMessages = (messages) => {
 
 const handlePrompt = async (promptConfig, sendToClient) => {
     const {
-        model: modelConfig, // Now expects the full model object
+        model: modelConfig,
         uuid,
         session,
         messageHistory,
@@ -66,42 +106,26 @@ const handlePrompt = async (promptConfig, sendToClient) => {
     } = promptConfig;
 
     try {
-        // Create messages array if not provided in history
         const messages = messageHistory.length ? messageHistory : [
             { role: "system", content: systemPrompt },
             { role: "user", content: userPrompt },
         ];
 
-        // Validate message format
-        if (!validateMessages(messages)) {
-            throw new Error('Invalid message format in conversation history');
-        }
-
-        // Create provider-specific client
-        const client = createClient(modelConfig.provider, {
-            apiKey: modelConfig.apiKey,
-            apiEndpoint: modelConfig.apiEndpoint,
-        });
+        const client = await createClient(modelConfig.provider, modelConfig);
 
-        //Create the prompt object to pass forward to the function
-        let promptPayload = {
+        let promptPayload = {
             model: modelConfig.model,
             messages,
             temperature: Math.max(0, Math.min(1, parseFloat(temperature) || 0.5)),
             stream: true,
-        }
-
-        //Handle model specific limitations
-        if(modelConfig.model == 'o3-mini-2025-01-31') delete promptPayload.temperature;
+        };
 
-        // Handle provider-specific prompts
         const responseStream = await handleProviderPrompt(
             client,
             modelConfig.provider,
             promptPayload
         );
 
-        // Process the response stream
         await handleProviderResponse(
             responseStream,
             modelConfig.provider,
@@ -111,9 +135,9 @@ const handlePrompt = async (promptConfig, sendToClient) => {
         );
 
     } catch (error) {
-        console.error("Prompt error:", error);
         sendToClient(uuid, session, "ERROR", JSON.stringify({
-            message: error.message || "An error occurred while processing the prompt"
+            message: error.message || "An error occurred while processing the prompt",
+            details: error.stack
         }));
     }
 };
@@ -145,6 +169,13 @@ const handleProviderPrompt = async (client, provider, config) => {
         case 'gemini':
             return handleGeminiPrompt(client, config);
 
+        case 'ollama':
+            if (!client?.completions?.create) {
+                console.error('Invalid client:', client);
+                throw new Error('Invalid Ollama client configuration');
+            }
+            return client.completions.create(config);
+
         default:
             throw new Error(`Unsupported provider: ${provider}`);
     }
@@ -225,95 +256,130 @@ const handleGeminiPrompt = async (client, config) => {
 
 // Handle provider responses
 const handleProviderResponse = async (responseStream, provider, uuid, session, sendToClient) => {
-  // Normalize provider name to lowercase
-  provider = provider.toLowerCase();
+    provider = provider.toLowerCase();
+
+    if (provider === 'ollama') {
+        try {
+            const reader = responseStream.getReader();
+            const decoder = new TextDecoder();
+
+            while (true) {
+                const { done, value } = await reader.read();
+                if (done) break;
+
+                const chunk = decoder.decode(value);
+                const lines = chunk.split('\n').filter(line => line.trim());
+
+                for (const line of lines) {
+                    try {
+                        const response = JSON.parse(line);
+                        if (response.response) {
+                            sendToClient(uuid, session, "message", response.response);
+                        }
+                    } catch (e) {
+                        // Silent catch for parsing errors
+                    }
+                }
+            }
+
+            sendToClient(uuid, session, "EOM", null);
+
+        } catch (error) {
+            sendToClient(uuid, session, "ERROR", JSON.stringify({
+                message: "Error processing Ollama stream",
+                error: error.message
+            }));
+        }
+        return;
+    }
 
-  // Handle Gemini separately
-  if (provider === "gemini") {
-    for await (const chunk of responseStream.stream) {
-      sendToClient(uuid, session, "message", chunk.text());
+    // Handle Gemini separately
+    if (provider === "gemini") {
+        for await (const chunk of responseStream.stream) {
+            sendToClient(uuid, session, "message", chunk.text());
+        }
+        sendToClient(uuid, session, "EOM", null);
+        return;
     }
-    sendToClient(uuid, session, "EOM", null);
-    return;
-  }
 
-  // Handle Azure separately
-  if (provider === "azureai") {
-    const stream = Readable.from(responseStream);
-    handleAzureStream(stream, uuid, session, sendToClient);
-    return;
-  }
+    // Handle Azure separately
+    if (provider === "azureai") {
+        const stream = Readable.from(responseStream);
+        handleAzureStream(stream, uuid, session, sendToClient);
+        return;
+    }
 
-  // Handle other providers
-  let messageEnded = false;
-  for await (const part of responseStream) {
-    try {
-      let content = null;
-
-      switch (provider) {
-        case "openai":
-          content = part?.choices?.[0]?.delta?.content;
-          messageEnded = part?.choices?.[0]?.finish_reason === "stop";
-          break;
-        case "anthropic":
-          if (part.type === "message_stop") {
-            messageEnded = true;
-          } else {
-            content = part?.content_block?.text || part?.delta?.text || "";
-          }
-          break;
-        case "mistral":
-          content = part?.data?.choices?.[0]?.delta?.content;
-          messageEnded = part?.data?.choices?.[0]?.finishReason === "stop";
-          break;
-        case "groq":
-          content = part?.choices?.[0]?.delta?.content;
-          messageEnded = part?.choices?.[0]?.finish_reason === "stop";
-          break;
-      }
+    // Handle other providers
+    let messageEnded = false;
+    for await (const part of responseStream) {
+        try {
+            let content = null;
+
+            switch (provider) {
+                case "openai":
+                    content = part?.choices?.[0]?.delta?.content;
+                    messageEnded = part?.choices?.[0]?.finish_reason === "stop";
+                    break;
+                case "anthropic":
+                    if (part.type === "message_stop") {
+                        messageEnded = true;
+                    } else {
+                        content = part?.content_block?.text || part?.delta?.text || "";
+                    }
+                    break;
+                case "mistral":
+                    content = part?.data?.choices?.[0]?.delta?.content;
+                    messageEnded = part?.data?.choices?.[0]?.finishReason === "stop";
+                    break;
+                case "groq":
+                    content = part?.choices?.[0]?.delta?.content;
+                    messageEnded = part?.choices?.[0]?.finish_reason === "stop";
+                    break;
+            }
+
+            if (content) {
+                sendToClient(uuid, session, "message", content);
+            }
+
+            // Send EOM if we've reached the end of the message
+            if (messageEnded) {
+                sendToClient(uuid, session, "EOM", null);
+            }
+        } catch (error) {
+            console.error(`Error processing ${provider} stream message:`, error);
+            sendToClient(uuid, session, "ERROR", JSON.stringify({
+                message: "Error processing stream message",
+                error: error.message,
+                provider: provider
+            }));
+        }
+    }
 
-    if (content) {
-      sendToClient(uuid, session, "message", content);
-    }
-
-    // Send EOM if we've reached the end of the message
-    if (messageEnded) {
+    // Send final EOM if not already sent
+    if (!messageEnded) {
        sendToClient(uuid, session, "EOM", null);
-      }
-    } catch (error) {
-      console.error(`Error processing ${provider} stream message:`, error);
-      sendToClient(uuid, session, "ERROR", JSON.stringify({
-        message: "Error processing stream message",
-        error: error.message,
-        provider: provider
-      }));
     }
-  }
-
-  // Send final EOM if not already sent
-  if (!messageEnded) {
-    sendToClient(uuid, session, "EOM", null);
-  }
 };
+
 // Handle AzureAI specific stream
 const handleAzureStream = (stream, uuid, session, sendToClient) => {
-  stream.on("data", (event) => {
-    event.choices.forEach((choice) => {
-      if (choice.delta?.content !== undefined) {
-        sendToClient(uuid, session, "message", choice.delta.content);
-      }
+    stream.on("data", (event) => {
+        event.choices.forEach((choice) => {
+            if (choice.delta?.content !== undefined) {
+                sendToClient(uuid, session, "message", choice.delta.content);
+            }
+        });
     });
-  });
-  stream.on("end", () => sendToClient(uuid, session, "EOM", null));
-  stream.on("error", (error) => {
-    sendToClient(uuid, session, "ERROR", JSON.stringify({
-      message: "Stream error.",
-      error: error.message
-    }));
-  });
+    stream.on("end", () => sendToClient(uuid, session, "EOM", null));
+    stream.on("error", (error) => {
+        sendToClient(uuid, session, "ERROR", JSON.stringify({
+            message: "Stream error.",
+            error: error.message
+        }));
+    });
 };
 
 module.exports = {
-  handlePrompt
+    handlePrompt
 };
\ No newline at end of file
diff --git a/public/App.js b/public/App.js
index 0900b7a..7f52989 100644
--- a/public/App.js
+++ b/public/App.js
@@ -111,9 +111,22 @@ export default {
         fileInput.value.click();
     }
 
+    // Add Ollama model detection
+    const detectOllamaModels = async () => {
+        try {
+            const response = await fetch('http://localhost:11434/api/tags');
+            const data = await response.json();
+            return data.models || [];
+        } catch (error) {
+            console.warn('Ollama not detected locally:', error);
+            return [];
+        }
+    };
+
     Vue.onMounted(async ()=>{
         await getConfigs();
         await fetchServerModels();
+        await detectOllamaModels(); // Add Ollama model detection
         await socketIoConnection();
     });
 
diff --git a/public/components/AgentCard.js b/public/components/AgentCard.js
index 25f07a0..d3913e0 100644
--- a/public/components/AgentCard.js
+++ b/public/components/AgentCard.js
@@ -189,6 +189,12 @@ export default {
+
+
+