diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index ddd0873..9cbe6bf 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -4,8 +4,6 @@ on: push: branches: [master] tags: ['v*'] - pull_request: - branches: [master] workflow_dispatch: jobs: diff --git a/Cargo.lock b/Cargo.lock index 366e96b..b97df20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,7 +4,7 @@ version = 4 [[package]] name = "OxideChat" -version = "0.1.0" +version = "0.1.1" dependencies = [ "aes-gcm", "arc-swap", @@ -17,6 +17,7 @@ dependencies = [ "dotenv", "extism", "futures-util", + "infer", "oauth2", "omniference", "rand 0.8.5", @@ -544,6 +545,17 @@ dependencies = [ "shlex", ] +[[package]] +name = "cfb" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38f2da7a0a2c4ccf0065be06397cc26a81f4e528be095826eee9d4adbb8c60f" +dependencies = [ + "byteorder", + "fnv", + "uuid", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -993,7 +1005,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1744,6 +1756,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "infer" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc150e5ce2330295b8616ce0e3f53250e53af31759a9dbedad1621ba29151847" +dependencies = [ + "cfb", +] + [[package]] name = "inout" version = "0.1.4" @@ -2023,6 +2044,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -2067,7 +2098,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2156,7 +2187,7 @@ dependencies = [ [[package]] name = "omniference" -version = "0.1.1" +version = "0.1.3" dependencies = [ "anyhow", "async-stream", @@ -2576,7 +2607,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -2782,6 +2813,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "percent-encoding", "pin-project-lite", @@ -2941,7 +2973,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.11.0", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3577,7 +3609,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix 1.1.3", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -4026,6 +4058,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -4750,7 +4788,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 6e9e82c..938aba2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "OxideChat" -version = "0.1.0" +version = "0.1.1" edition = "2024" description = "A modern, performance and on customization focused AI chat application" license = "MIT" -repository = "https://github.com/yourusername/OxideChat" +repository = "https://github.com/NxtCore/OxideChat" readme = "README.md" keywords = ["chat", "axum", "websocket", "postgres"] categories = ["web-programming"] @@ -32,6 +32,7 @@ oauth2 = "5" reqwest = { version = "0.12", default-features = false, features = [ "rustls-tls", "json", + "multipart", ] } tower-http = { version = "0.5", features = ["cors"] } omniference = "0.1" @@ -46,6 +47,7 @@ sha2 = "0.10" base64 = "0.22" thiserror = "2" async-trait = "0.1" +infer = "0.16" [lints.clippy] correctness = { level = "deny", priority = -1 } diff --git a/Dockerfile b/Dockerfile index d5ec6b8..b49d4be 100644 --- a/Dockerfile +++ b/Dockerfile @@ -69,7 +69,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && groupadd -g 1000 app \ && useradd -u 1000 -g app -s /bin/bash -m app \ && mkdir -p /var/cache/nginx /var/log/nginx /var/lib/nginx /run /tmp/nginx \ - && chown -R app:app /var/cache/nginx /var/log/nginx /var/lib/nginx /run /tmp/nginx + && mkdir -p /app/uploads/images \ + && chown -R app:app /var/cache/nginx /var/log/nginx /var/lib/nginx /run /tmp/nginx /app/uploads COPY --from=frontend-builder /usr/local/bin/bun /usr/local/bin/bun @@ -106,6 +107,11 @@ RUN printf '%s\n' \ USER app +# Image storage configuration: +# IMAGE_STORAGE_TYPE=database (default) | file +# IMAGE_STORAGE_PATH=/app/uploads/images (default for file storage) +VOLUME ["/app/uploads"] + EXPOSE 8080 HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \ diff --git a/frontend/components/ImagePreview.vue b/frontend/components/ImagePreview.vue new file mode 100644 index 0000000..d0e8a29 --- /dev/null +++ b/frontend/components/ImagePreview.vue @@ -0,0 +1,148 @@ + + + + + diff --git a/frontend/components/chat/ChatComposer.vue b/frontend/components/chat/ChatComposer.vue index cdf2176..cb6c989 100644 --- a/frontend/components/chat/ChatComposer.vue +++ b/frontend/components/chat/ChatComposer.vue @@ -2,19 +2,40 @@
+
+
+ Attached image + +
+
+ + diff --git a/frontend/components/chat/ChatEmptyState.vue b/frontend/components/chat/ChatEmptyState.vue index e205cc8..b646d17 100644 --- a/frontend/components/chat/ChatEmptyState.vue +++ b/frontend/components/chat/ChatEmptyState.vue @@ -8,7 +8,7 @@ {{ description }}

- +
@@ -20,7 +20,7 @@ import {useChatStore} from '~/stores/chatStore'; import {useMainStore} from '~/stores'; const emit = defineEmits<{ - send: [content: string]; + send: [content: string, parts?: any[]]; }>(); const store = useMainStore(); @@ -45,12 +45,8 @@ const description = computed(() => { return placeholders[Math.floor(Math.random() * placeholders.length)]; }); -function handleSend(content: string) { - emit('send', content); -} - -function onSend(content: string | undefined) { - if (!content) return; - handleSend(content); +function onSend(content: string, parts?: any[]) { + if (!content && (!parts || parts.length === 0)) return; + emit('send', content, parts); } diff --git a/frontend/components/chat/ChatView.vue b/frontend/components/chat/ChatView.vue index 028d02c..69c4b9c 100644 --- a/frontend/components/chat/ChatView.vue +++ b/frontend/components/chat/ChatView.vue @@ -10,7 +10,7 @@ import ChatEmptyState from './ChatEmptyState.vue'; const chatStore = useChatStore(); -async function handleSendMessage(content: string) { +async function handleSendMessage(content: string, parts?: any[]) { let chatId = chatStore.activeChat?.id; if (!chatId) { @@ -21,6 +21,6 @@ async function handleSendMessage(content: string) { chatId = chat.id; } - await chatStore.sendAndStream(chatId, content); + await chatStore.sendAndStream(chatId, content, parts); } diff --git a/frontend/components/chat/MessageItem.vue b/frontend/components/chat/MessageItem.vue index 3005e79..a4eef41 100644 --- a/frontend/components/chat/MessageItem.vue +++ b/frontend/components/chat/MessageItem.vue @@ -39,11 +39,22 @@ :name="tool.tool_name" :args="tool.input_args" :output="tool.output" - :error="tool.error" + :error="tool.error || undefined" :is-executing="!tool.output && !tool.error" />
+
+
+ +
+
+
+
@@ -78,6 +90,7 @@ import {User, Bot, Brain, ChevronDown} from 'lucide-vue-next'; import MessageActions from './MessageActions.vue'; import CodePreview from './CodePreview.vue'; +import ImagePreview from '~/components/ImagePreview.vue'; import ToolExecutionDisplay from './ToolExecutionDisplay.vue'; import type {ChatMessage} from '~/types/chat'; import {useChatStore} from '~/stores/chatStore'; @@ -98,11 +111,26 @@ const {renderStreaming, renderComplete} = useMarkdown(); const showReasoning = ref(false); const previewData = ref<{code: string; language: string} | null>(null); +const showImagePreview = ref(false); +const imagePreviewUrl = ref(null); +const imagePreviewFilename = ref(undefined); const isUser = computed(() => props.message.role === 'user'); const isStreaming = computed(() => props.message.id.startsWith('streaming-')); const isStreamingReasoning = computed(() => isStreaming.value && chatStore.isStreaming && !props.message.content); +const attachedImages = computed(() => { + if (!props.message.content_parts || !Array.isArray(props.message.content_parts)) { + return []; + } + return props.message.content_parts + .filter((part: any) => part.type === 'image' && part.image_id) + .map((part: any) => ({ + image_id: part.image_id, + url: `/api/v1/images/${part.image_id}`, + })); +}); + const model = computed(() => { if (isUser.value) return null; return chatStore.models.find(m => m.id === props.message.model_id); @@ -188,11 +216,30 @@ function handleCodeBlockClick(event: MouseEvent) { previewData.value = result; } } + + const imgTag = target.closest('img') as HTMLImageElement | null; + if (imgTag && imgTag.src) { + const alt = imgTag.alt || 'image'; + const filename = `${alt.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9-_]/g, '')}.png`; + openImagePreview(imgTag.src, filename); + } } function closePreview() { previewData.value = null; } + +function openImagePreview(url: string, filename?: string) { + imagePreviewUrl.value = url; + imagePreviewFilename.value = filename; + showImagePreview.value = true; +} + +function closeImagePreview() { + showImagePreview.value = false; + imagePreviewUrl.value = null; + imagePreviewFilename.value = undefined; +} diff --git a/frontend/components/chat/MessageList.vue b/frontend/components/chat/MessageList.vue index d77a839..0ca4204 100644 --- a/frontend/components/chat/MessageList.vue +++ b/frontend/components/chat/MessageList.vue @@ -15,11 +15,10 @@
- diff --git a/frontend/components/chat/ToolExecutionDisplay.vue b/frontend/components/chat/ToolExecutionDisplay.vue index b086ced..1d407f8 100644 --- a/frontend/components/chat/ToolExecutionDisplay.vue +++ b/frontend/components/chat/ToolExecutionDisplay.vue @@ -1,6 +1,14 @@ diff --git a/frontend/composables/useMarkdown.ts b/frontend/composables/useMarkdown.ts index 7aa69a4..82f4e54 100644 --- a/frontend/composables/useMarkdown.ts +++ b/frontend/composables/useMarkdown.ts @@ -12,10 +12,23 @@ import {Marked} from 'marked'; import remend from 'remend'; function sanitize(html: string): string { - return DOMPurify.sanitize(html, { - ADD_ATTR: ['data-language', 'data-previewable', 'title'], - ADD_TAGS: ['button'], + DOMPurify.addHook('afterSanitizeAttributes', node => { + if (node.tagName === 'A' && node.hasAttribute('href')) { + const href = node.getAttribute('href') || ''; + if (href.startsWith('data:')) { + node.removeAttribute('href'); + } + } + }); + + const result = DOMPurify.sanitize(html, { + ADD_ATTR: ['data-language', 'data-previewable', 'title', 'src', 'alt'], + ADD_TAGS: ['button', 'img'], + ALLOWED_URI_REGEXP: /^(?:(?:(?:f|ht)tps?|mailto|tel|callto|sms|cid|xmpp|data):|[^a-z]|[a-z+.\-]+(?:[^a-z+.\-:]|$))/i, }); + + DOMPurify.removeHook('afterSanitizeAttributes'); + return result; } // Language aliases for normalization diff --git a/frontend/pages/chats/[id].vue b/frontend/pages/chats/[id].vue index e52d2c3..030a62e 100644 --- a/frontend/pages/chats/[id].vue +++ b/frontend/pages/chats/[id].vue @@ -14,7 +14,7 @@ import {useRoute} from '#app'; const chatStore = useChatStore(); const route = useRoute(); -async function handleSendMessage(content: string) { +async function handleSendMessage(content: string, parts?: any[]) { let chatId = chatStore.activeChat?.id; if (!chatId) { const chat = await chatStore.createChat({ @@ -23,7 +23,7 @@ async function handleSendMessage(content: string) { if (!chat) return; chatId = chat.id; } - await chatStore.sendAndStream(chatId, content); + await chatStore.sendAndStream(chatId, content, parts); } if (!chatStore.activeChat || chatStore.activeChat.id !== route.params.id) { diff --git a/frontend/pages/settings/tools.vue b/frontend/pages/settings/tools.vue index 71e95ff..9ace36f 100644 --- a/frontend/pages/settings/tools.vue +++ b/frontend/pages/settings/tools.vue @@ -457,6 +457,69 @@ const builtinToolTemplates = [ }, }, }, + { + name: 'imagegen', + display_name: 'Image Generation', + description: 'Generate and edit images using OpenAI, Replicate, or Google APIs', + source_kind: 'BUILTIN', + icon: Sparkles, + source_config: {builtin_id: 'imagegen'}, + functions: [ + { + name: 'generate', + description: 'Generate an image from a text prompt', + input_schema: { + type: 'object', + properties: { + prompt: {type: 'string', description: 'The text prompt describing the image to generate'}, + size: { + type: 'string', + description: 'Image size', + enum: ['1024x1024', '1792x1024', '1024x1792', '512x512', '256x256'], + default: '1024x1024', + }, + quality: { + type: 'string', + description: 'Image quality', + enum: ['standard', 'hd'], + default: 'standard', + }, + }, + required: ['prompt'], + }, + }, + { + name: 'edit', + description: 'Edit an existing image using a text prompt', + input_schema: { + type: 'object', + properties: { + image_url: {type: 'string', description: 'URL of the image to edit'}, + prompt: {type: 'string', description: 'The text prompt describing the desired edit'}, + }, + required: ['image_url', 'prompt'], + }, + }, + ], + settings_schema: { + type: 'object', + required: ['api_key', 'provider'], + properties: { + api_key: {type: 'string', title: 'API Key', secret: true, description: 'API key for the selected provider'}, + provider: { + type: 'string', + title: 'Provider', + enum: ['openai', 'replicate', 'google'], + description: 'Image generation provider to use', + }, + model: { + type: 'string', + title: 'Model', + description: 'Model to use (optional, defaults: dall-e-3, flux-schnell, imagen-3)', + }, + }, + }, + }, ]; const displayTools = computed(() => { @@ -717,7 +780,8 @@ async function saveTool() { body_template: httpConfig.body_template || null, }; } else if (toolForm.source_kind === 'BUILTIN') { - source_config = {builtin_id: builtinConfig.builtin_id}; + const builtin = displayTools.value.find(t => t.name === toolForm.name); + source_config = {builtin_id: builtin?.source_config?.builtin_id || builtin?.name || toolForm.name}; } else if (toolForm.source_kind === 'WASM') { source_config = { wasm_blob_id: wasmConfig.blob_id, diff --git a/frontend/stores/chatStore.ts b/frontend/stores/chatStore.ts index 32eff6a..2901506 100644 --- a/frontend/stores/chatStore.ts +++ b/frontend/stores/chatStore.ts @@ -328,7 +328,7 @@ export const useChatStore = defineStore('chat', { } }, - async sendAndStream(chatId: string, content: string): Promise { + async sendAndStream(chatId: string, content: string, parts?: any[]): Promise { if (!this.selectedModel) { console.error('No model selected'); return; @@ -395,19 +395,25 @@ export const useChatStore = defineStore('chat', { const config = useRuntimeConfig(); const baseUrl = config.public.apiBase || ''; + const body: any = { + content, + model_key: this.selectedModel.model_id, + reasoning_effort: this.reasoningEffort || undefined, + reasoning_budget_tokens: this.reasoningBudget || undefined, + tools_enabled: this.enabledTools.length > 0 ? this.enabledTools : undefined, + }; + + if (parts && parts.length > 0) { + body.parts = parts; + } + const response = await fetch(`${baseUrl}/api/v1/chats/${chatId}/stream`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, credentials: 'include', - body: JSON.stringify({ - content, - model_key: this.selectedModel.model_id, - reasoning_effort: this.reasoningEffort || undefined, - reasoning_budget_tokens: this.reasoningBudget || undefined, - tools_enabled: this.enabledTools.length > 0 ? this.enabledTools : undefined, - }), + body: JSON.stringify(body), }); if (!response.ok) { @@ -433,7 +439,7 @@ export const useChatStore = defineStore('chat', { for (const line of lines) { if (line.startsWith('data: ')) { - const jsonStr = line.slice(6); + const jsonStr = line.slice(6).trim(); if (!jsonStr) continue; try { @@ -474,7 +480,15 @@ export const useChatStore = defineStore('chat', { const toolCall = msg?.tool_calls?.find(tc => tc.tool_call_id === data.id); if (toolCall && typeof toolCall.input_args === 'string') { toolCall.input_args += data.args_delta; + } else { + const newToolCall = { + tool_call_id: data.id, + tool_name: data.name, + input_args: data.args_delta, + }; + if (msg) msg.tool_calls.push(newToolCall as any); } + break; } case 'tool_call_end': { diff --git a/frontend/stores/icons.ts b/frontend/stores/icons.ts index 0da4164..15292cd 100644 --- a/frontend/stores/icons.ts +++ b/frontend/stores/icons.ts @@ -88,7 +88,6 @@ export const useIconsStore = defineStore('icons', { const provider = state.providersMeta[state.providerLookup[normalized]]; return {icon: provider.icon, type: provider.type as 'svg' | 'png'}; } - console.log(normalized); for (const [key, meta] of Object.entries(state.providersMeta)) { if ( meta.variants.some(variant => { diff --git a/frontend/types/chat.ts b/frontend/types/chat.ts index fc45b56..b63f2de 100644 --- a/frontend/types/chat.ts +++ b/frontend/types/chat.ts @@ -47,6 +47,7 @@ export interface ChatMessage { content: string; reasoning_content: string | null; model_id: string | null; + content_parts?: Array<{type: string; text?: string; image_id?: string}> | null; cost_details: { input: string | null; output: string | null; diff --git a/migrations/20251226000000_initial_schema.sql b/migrations/20251226000000_initial_schema.sql index e383236..1848a13 100644 --- a/migrations/20251226000000_initial_schema.sql +++ b/migrations/20251226000000_initial_schema.sql @@ -230,7 +230,7 @@ CREATE TABLE IF NOT EXISTS model_access ( ); -- Workspaces (linked to users) -CREATE TABLE workspaces ( +CREATE TABLE IF NOT EXISTS workspaces ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, name VARCHAR(100) NOT NULL, @@ -244,7 +244,7 @@ CREATE TABLE workspaces ( ); -- Chats (linked to workspaces) -CREATE TABLE chats ( +CREATE TABLE IF NOT EXISTS chats ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, workspace_id UUID REFERENCES workspaces(id) ON DELETE SET NULL, @@ -256,7 +256,7 @@ CREATE TABLE chats ( ); -- Messages -CREATE TABLE messages ( +CREATE TABLE IF NOT EXISTS messages ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE, role VARCHAR(20) NOT NULL, @@ -265,13 +265,14 @@ CREATE TABLE messages ( model_id UUID REFERENCES models(id) ON DELETE SET NULL, created_at TIMESTAMPTZ DEFAULT NOW(), updated_at TIMESTAMPTZ DEFAULT NOW(), + content_parts JSONB DEFAULT '[]', cost_details JSONB DEFAULT '{}', usage_details JSONB DEFAULT '{}', reasoning_details JSONB DEFAULT '{}' ); -- User preferences (streaming animation, default model, etc.) -CREATE TABLE user_preferences ( +CREATE TABLE IF NOT EXISTS user_preferences ( user_id UUID PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, default_model_key VARCHAR(255), favorite_model_keys JSONB DEFAULT '[]', @@ -283,6 +284,16 @@ CREATE TABLE user_preferences ( updated_at TIMESTAMPTZ DEFAULT NOW() ); +CREATE TABLE IF NOT EXISTS images ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + data BYTEA, -- NULL if using file storage + file_path VARCHAR(500), -- Path relative to storage root (for file storage) + mime_type VARCHAR(64) NOT NULL DEFAULT 'image/png', + size_bytes BIGINT NOT NULL, + source VARCHAR(50), -- 'imagegen', 'upload', etc. + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + INSERT INTO roles (name) VALUES ('admin'), ('user') ON CONFLICT DO NOTHING; INSERT INTO permissions (name, description) VALUES @@ -481,6 +492,9 @@ CREATE INDEX IF NOT EXISTS idx_chats_updated ON chats(updated_at DESC); CREATE INDEX IF NOT EXISTS idx_chats_pinned ON chats(user_id, is_pinned) WHERE is_pinned = true; CREATE INDEX IF NOT EXISTS idx_messages_chat ON messages(chat_id); CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(created_at); +CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_images_user_id ON images(user_id); +CREATE INDEX IF NOT EXISTS idx_messages_content_parts_gin ON messages USING GIN (content_parts); -- Tool indexes CREATE INDEX IF NOT EXISTS idx_wasm_blobs_owner ON wasm_blobs(owner_id); @@ -1121,6 +1135,18 @@ INSERT INTO i18n_translations (language, key_path, value) VALUES ('de', 'chat.tool_execution.error', 'Fehler'), ('de', 'chat.tool_execution.completed_in', 'Abgeschlossen in {ms}ms'), + -- Image Generation + ('en', 'chat.image_preview.download', 'Download'), + ('en', 'chat.image_preview.copy', 'Copy URL'), + ('en', 'chat.image_preview.copied', 'Copied!'), + ('en', 'chat.tool_execution.generated_image', 'Generated Image'), + + ('de', 'chat.image_preview.download', 'Herunterladen'), + ('de', 'chat.image_preview.copy', 'URL kopieren'), + ('de', 'chat.image_preview.copied', 'Kopiert!'), + ('de', 'chat.tool_execution.generated_image', 'Generiertes Bild'); + + -- Schema Builder ('en', 'settings.schema_builder.type', 'Type'), ('en', 'settings.schema_builder.default', 'Default'), diff --git a/src/routes/admin/tools.rs b/src/routes/admin/tools.rs index 09602f8..358274f 100644 --- a/src/routes/admin/tools.rs +++ b/src/routes/admin/tools.rs @@ -535,6 +535,7 @@ pub async fn test_tool(State(state): State>, cookies: Cookies, Pat settings: tool_settings.map(|s| s.settings).unwrap_or_default(), timeout_ms: Some(30000), function_name: req.function_name.clone(), + db: Some(std::sync::Arc::new(state.db.clone())), }; let start = std::time::Instant::now(); diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 9149bb3..aed5200 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -73,5 +73,8 @@ pub fn build_router() -> Router> { .route("/api/v1/models", get(public::models::list_models)) // Tools .route("/api/v1/tools", get(public::tools::list_tools)) + // Images CDN (public, no auth) + .route("/api/v1/images/{id}", get(public::images::serve_image)) + .route("/api/v1/images", post(public::images::upload_image)) .layer(CookieManagerLayer::new()) } diff --git a/src/routes/public/images.rs b/src/routes/public/images.rs new file mode 100644 index 0000000..a43f1cb --- /dev/null +++ b/src/routes/public/images.rs @@ -0,0 +1,76 @@ +//! Image CDN routes for serving and uploading images. +//! +//! Provides public endpoints for image storage: +//! - GET /api/v1/images/:id - Serve an image by UUID +//! - POST /api/v1/images - Upload a base64 image (internal use) + +use crate::AppState; +use crate::utils::images::{get_image, image_url, store_from_data_uri}; +use axum::{ + Json, + extract::{Path, State}, + http::{HeaderMap, StatusCode, header}, + response::{IntoResponse, Response}, +}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use uuid::Uuid; + +/// Request body for uploading an image +#[derive(Debug, Deserialize)] +pub struct UploadImageRequest { + /// Base64 data URI (e.g., "data:image/png;base64,...") + pub data_uri: String, + /// Optional user ID for attribution + pub user_id: Option, + /// Optional source identifier (e.g., "imagegen") + pub source: Option, +} + +/// Response for successful image upload +#[derive(Debug, Serialize)] +pub struct UploadImageResponse { + pub id: Uuid, + pub url: String, + pub mime_type: String, + pub size_bytes: i64, +} + +/// Upload a base64 image and return its URL +/// +/// POST /api/images +pub async fn upload_image(State(state): State>, Json(req): Json) -> Result, (StatusCode, String)> { + let stored = store_from_data_uri(&state.db, &req.data_uri, req.user_id, req.source.as_deref()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + + Ok(Json(UploadImageResponse { + id: stored.id, + url: image_url(stored.id), + mime_type: stored.mime_type, + size_bytes: stored.size_bytes, + })) +} + +/// Serve an image by ID +/// +/// GET /api/v1/images/:id +pub async fn serve_image(State(state): State>, Path(id): Path) -> Response { + match get_image(&state.db, id).await { + Ok(Some((data, mime_type))) => { + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + mime_type.parse().unwrap_or(header::HeaderValue::from_static("application/octet-stream")), + ); + headers.insert(header::CACHE_CONTROL, header::HeaderValue::from_static("public, max-age=31536000, immutable")); + + (StatusCode::OK, headers, data).into_response() + } + Ok(None) => (StatusCode::NOT_FOUND, "Image not found").into_response(), + Err(e) => { + eprintln!("[IMAGES] Failed to retrieve image {id}: {e}"); + (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error").into_response() + } + } +} diff --git a/src/routes/public/mod.rs b/src/routes/public/mod.rs index 6b520b9..3ceea0b 100644 --- a/src/routes/public/mod.rs +++ b/src/routes/public/mod.rs @@ -1,6 +1,7 @@ pub mod auth; pub mod base; pub mod chats; +pub mod images; pub mod messages; pub mod models; pub mod oauth; diff --git a/src/routes/public/streaming.rs b/src/routes/public/streaming.rs index bf35fea..bdd9e3b 100644 --- a/src/routes/public/streaming.rs +++ b/src/routes/public/streaming.rs @@ -27,10 +27,20 @@ use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Instant}; use tower_cookies::Cookies; use uuid::Uuid; +/// Structured message part (text or image) +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum MessagePart { + Text { text: String }, + Image { image_id: String }, +} + /// Request body for sending a message and streaming AI response #[derive(Debug, Deserialize)] pub struct StreamRequest { pub content: String, + #[serde(default)] + pub parts: Option>, pub model_key: String, pub reasoning_effort: Option, pub reasoning_budget_tokens: Option, @@ -152,7 +162,7 @@ fn merge_reasoning_with_priority( async fn execute_tool_by_name(db: &sqlx::PgPool, user_id: Uuid, full_tool_name: &str, input: serde_json::Value) -> Result { use crate::types::ToolSourceKind; - let mut tool: Option = None; + let mut tool: Option; let mut function_name: Option = None; let mut function_id: Option = None; @@ -220,6 +230,7 @@ async fn execute_tool_by_name(db: &sqlx::PgPool, user_id: Uuid, full_tool_name: settings, timeout_ms: Some(30000), function_name: function_name.map(|s| s.to_string()), + db: Some(std::sync::Arc::new(db.clone())), }; match tool.source_kind { @@ -298,8 +309,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook } }; - eprintln!("[STREAM] Chat verified"); - let model = sqlx::query_as::<_, crate::types::AiModel>("SELECT * FROM models WHERE model_id = $1") .bind(&req.model_key) .fetch_optional(&state.db) @@ -314,8 +323,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook } }; - eprintln!("[STREAM] Model verified"); - let model_config = sqlx::query_as::<_, ModelConfig>("SELECT * FROM model_configs WHERE owner_id = $1 AND stable_key = $2") .bind(user.id) .bind(&req.model_key) @@ -324,8 +331,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook .ok() .flatten(); - eprintln!("[STREAM] Model config: {:?}", model_config.as_ref().map(|mc| &mc.name)); - let reasoning_details = crate::types::ReasoningDetails { effort: req.reasoning_effort.clone(), budget_tokens: req.reasoning_budget_tokens.map(|b| b as i32), @@ -333,15 +338,18 @@ pub async fn stream_completion(State(state): State>, cookies: Cook let usage_details = crate::types::UsageDetails::default(); let cost_details = crate::types::CostDetails::default(); + let content_parts_json = req.parts.as_ref().map(|parts| serde_json::to_value(parts).ok()).flatten(); + let user_message = sqlx::query_as::<_, Message>( r#" - INSERT INTO messages (chat_id, role, content, model_id, reasoning_details, usage_details, cost_details) - VALUES ($1, 'user', $2, $3, $4, $5, $6) + INSERT INTO messages (chat_id, role, content, content_parts, model_id, reasoning_details, usage_details, cost_details) + VALUES ($1, 'user', $2, $3, $4, $5, $6, $7) RETURNING * "#, ) .bind(chat_id) .bind(&req.content) + .bind(content_parts_json) .bind(model.id) .bind(sqlx::types::Json(reasoning_details)) .bind(sqlx::types::Json(usage_details)) @@ -357,8 +365,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook } }; - eprintln!("[STREAM] User message saved"); - let messages = sqlx::query_as::<_, Message>("SELECT * FROM messages WHERE chat_id = $1 ORDER BY created_at ASC") .bind(chat_id) .fetch_all(&state.db) @@ -372,8 +378,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook } }; - eprintln!("[STREAM] Messages fetched"); - let engine = ai::get(); let engine_read = engine.read().await; @@ -386,8 +390,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook } }; - eprintln!("[STREAM] Model verified"); - let provider = match engine_read.get_provider(&omni_model.provider_name).await { Some(p) => p, None => { @@ -397,17 +399,45 @@ pub async fn stream_completion(State(state): State>, cookies: Cook } }; - eprintln!("[STREAM] Provider verified"); + let omni_messages: Vec = { + let mut result = Vec::new(); + for m in messages.iter().filter(|m| m.role == "user" || m.role == "assistant") { + let parts = if let Some(content_parts_json) = &m.content_parts { + if let Ok(stored_parts) = serde_json::from_value::>(content_parts_json.clone()) { + let mut omni_parts = Vec::new(); + for part in stored_parts { + match part { + MessagePart::Text { text } => { + omni_parts.push(ContentPart::Text(text)); + } + MessagePart::Image { image_id } => { + if let Ok(uuid) = uuid::Uuid::parse_str(&image_id) { + if let Ok(Some((data, mime))) = crate::utils::images::get_image(&state.db, uuid).await { + use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; + let b64 = BASE64.encode(&data); + let data_uri = format!("data:{};base64,{}", mime, b64); + omni_parts.push(ContentPart::ImageUrl { url: data_uri, mime: Some(mime) }); + } + } + } + } + } + omni_parts + } else { + vec![ContentPart::Text(m.content.clone())] + } + } else { + vec![ContentPart::Text(m.content.clone())] + }; - let omni_messages: Vec = messages - .iter() - .filter(|m| m.role == "user" || m.role == "assistant") - .map(|m| OmniMessage { - role: if m.role == "user" { Role::User } else { Role::Assistant }, - parts: vec![ContentPart::Text(m.content.clone())], - name: None, - }) - .collect(); + result.push(OmniMessage { + role: if m.role == "user" { Role::User } else { Role::Assistant }, + parts, + name: None, + }); + } + result + }; eprintln!("[STREAM] Messages built"); let mut ir = ChatRequestIR::default(); @@ -439,8 +469,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook .fetch_all(&state.db) .await; - eprintln!("[STREAM] Tools query result: {:?}", tools.as_ref().map(|t| t.len()).map_err(|e| e.to_string())); - if let Ok(tools) = tools { eprintln!("[STREAM] Found {} tools from DB", tools.len()); let tool_ids: Vec = tools.iter().map(|t| t.id).collect(); @@ -475,8 +503,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook } } - eprintln!("[STREAM] Chat request: {:?}", ir); - let omni_messages_for_stream = ir.messages.clone(); let ir_for_stream = ir.clone(); @@ -560,7 +586,6 @@ pub async fn stream_completion(State(state): State>, cookies: Cook let mut event_count = 0; while let Some(event) = upstream.next().await { event_count += 1; - eprintln!("[STREAM] Event #{}: {:?}", event_count, &event); match event { StreamEvent::TextDelta { content } => { if let Some(start) = reasoning_start.take() { @@ -774,7 +799,7 @@ pub async fn stream_completion(State(state): State>, cookies: Cook current_messages.push(OmniMessage { role: Role::Tool, parts: vec![ContentPart::Text(result_text)], - name: Some(format!("{}:{}", tool_name, call_id)), + name: Some(call_id.clone()), }); all_tool_executions.push((call_id, tool_name, args, output, error, exec_ms, tool_id, function_id)); diff --git a/src/types/chat.rs b/src/types/chat.rs index bb52c47..0d7e414 100644 --- a/src/types/chat.rs +++ b/src/types/chat.rs @@ -172,6 +172,7 @@ pub struct Message { pub chat_id: Uuid, pub role: String, pub content: String, + pub content_parts: Option, pub reasoning_content: Option, pub model_id: Option, pub cost_details: sqlx::types::Json, @@ -345,6 +346,8 @@ pub struct ChatMessageResponse { pub content: String, pub reasoning_content: Option, pub model_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub content_parts: Option, pub cost_details: CostDetails, pub usage_details: UsageDetails, pub reasoning_details: ReasoningDetails, @@ -361,6 +364,7 @@ impl From for ChatMessageResponse { content: m.content, reasoning_content: m.reasoning_content, model_id: m.model_id, + content_parts: m.content_parts, cost_details: m.cost_details.0, usage_details: m.usage_details.0, reasoning_details: m.reasoning_details.0, diff --git a/src/utils/images.rs b/src/utils/images.rs new file mode 100644 index 0000000..610a1a2 --- /dev/null +++ b/src/utils/images.rs @@ -0,0 +1,210 @@ +//! Image Storage Module +//! +//! Provides configurable storage backends for images (database or filesystem). +//! Set `IMAGE_STORAGE_TYPE` env var to "database" or "file" (default: "database"). +//! For file storage, set `IMAGE_STORAGE_PATH` (default: "./uploads/images"). + +use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; +use sqlx::PgPool; +use std::path::PathBuf; +use tokio::fs; +use uuid::Uuid; + +/// Storage type for images +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StorageType { + Database, + File, +} + +impl StorageType { + pub fn from_env() -> Self { + match std::env::var("IMAGE_STORAGE_TYPE").as_deref() { + Ok("file") | Ok("filesystem") => StorageType::File, + _ => StorageType::Database, + } + } +} + +/// Get the base path for file storage +pub fn storage_path() -> PathBuf { + std::env::var("IMAGE_STORAGE_PATH") + .map(PathBuf::from) + .unwrap_or_else(|_| PathBuf::from("./uploads/images")) +} + +/// Generate a URL path for an image (relative, works behind any reverse proxy) +pub fn image_url(id: Uuid) -> String { + format!("/api/v1/images/{}", id) +} +/// Stored image metadata +#[derive(Debug, Clone)] +pub struct StoredImage { + pub id: Uuid, + pub mime_type: String, + pub size_bytes: i64, +} + +/// Store an image from base64 data +/// +/// Returns the stored image metadata including the generated UUID. +pub async fn store_image(db: &PgPool, data: &[u8], mime_type: &str, user_id: Option, source: Option<&str>) -> Result { + let storage_type = StorageType::from_env(); + let id = Uuid::new_v4(); + let size_bytes = i64::try_from(data.len()).map_err(|_| "Image size too large to represent in 64-bit".to_string())?; + + match storage_type { + StorageType::Database => { + sqlx::query( + r#" + INSERT INTO images (id, data, mime_type, size_bytes, user_id, source) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + ) + .bind(id) + .bind(data) + .bind(mime_type) + .bind(size_bytes) + .bind(user_id) + .bind(source) + .execute(db) + .await + .map_err(|e| format!("Failed to store image in database: {e}"))?; + } + StorageType::File => { + let base_path = storage_path(); + fs::create_dir_all(&base_path).await.map_err(|e| format!("Failed to create storage directory: {e}"))?; + + let extension = mime_to_extension(mime_type); + let filename = format!("{id}.{extension}"); + let file_path = base_path.join(&filename); + + fs::write(&file_path, data).await.map_err(|e| format!("Failed to write image file: {e}"))?; + + let relative_path = filename; + + let db_result = sqlx::query( + r#" + INSERT INTO images (id, file_path, mime_type, size_bytes, user_id, source) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + ) + .bind(id) + .bind(&relative_path) + .bind(mime_type) + .bind(size_bytes) + .bind(user_id) + .bind(source) + .execute(db) + .await; + + if let Err(e) = db_result { + let _ = fs::remove_file(&file_path).await; + return Err(format!("Failed to store image metadata: {e}")); + } + } + } + + Ok(StoredImage { + id, + mime_type: mime_type.to_string(), + size_bytes, + }) +} + +/// Store an image from a base64 data URI (e.g., "data:image/png;base64,...") +pub async fn store_from_data_uri(db: &PgPool, data_uri: &str, user_id: Option, source: Option<&str>) -> Result { + let (mime_type, data) = parse_data_uri(data_uri)?; + store_image(db, &data, &mime_type, user_id, source).await +} + +/// Retrieve an image by ID +/// +/// Returns (data, mime_type) or None if not found. +pub async fn get_image(db: &PgPool, id: Uuid) -> Result, String)>, String> { + #[derive(sqlx::FromRow)] + struct ImageRow { + data: Option>, + file_path: Option, + mime_type: String, + } + + let row: Option = sqlx::query_as( + r#" + SELECT data, file_path, mime_type FROM images WHERE id = $1 + "#, + ) + .bind(id) + .fetch_optional(db) + .await + .map_err(|e| format!("Failed to fetch image: {e}"))?; + + match row { + Some(img) => { + let data = if let Some(data) = img.data { + data + } else if let Some(file_path) = img.file_path { + let full_path = storage_path().join(&file_path); + fs::read(&full_path).await.map_err(|e| format!("Failed to read image file: {e}"))? + } else { + return Err("Image data not found in database or file".to_string()); + }; + Ok(Some((data, img.mime_type))) + } + None => Ok(None), + } +} + +/// Parse a base64 data URI into (mime_type, decoded_bytes) +fn parse_data_uri(data_uri: &str) -> Result<(String, Vec), String> { + if !data_uri.starts_with("data:") { + return Err("Invalid data URI: must start with 'data:'".to_string()); + } + + let without_prefix = &data_uri[5..]; + let parts: Vec<&str> = without_prefix.splitn(2, ',').collect(); + + if parts.len() != 2 { + return Err("Invalid data URI format".to_string()); + } + + let header = parts[0]; + let data = parts[1]; + + let mime_type = if header.contains(';') { + let extracted = header.split(';').next().unwrap_or("image/png"); + if extracted.is_empty() { + "application/octet-stream".to_string() + } else { + extracted.to_string() + } + } else { + if header.is_empty() { + "application/octet-stream".to_string() + } else { + header.to_string() + } + }; + + let decoded = BASE64.decode(data).map_err(|e| format!("Failed to decode base64: {e}"))?; + + Ok((mime_type, decoded)) +} + +/// Convert MIME type to file extension +fn mime_to_extension(mime: &str) -> &'static str { + match mime { + "image/png" => "png", + "image/jpeg" | "image/jpg" => "jpg", + "image/gif" => "gif", + "image/webp" => "webp", + "image/svg+xml" => "svg", + "image/bmp" => "bmp", + _ => "bin", + } +} + +/// Check if an image URL is a base64 data URI that needs to be uploaded +pub fn is_data_uri(url: &str) -> bool { + url.starts_with("data:") +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 4278216..562b867 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,5 +1,6 @@ pub mod auth; pub mod encryption; +pub mod images; pub mod oauth; pub mod providers; pub mod response; diff --git a/src/utils/tools/builtin/imagegen.rs b/src/utils/tools/builtin/imagegen.rs new file mode 100644 index 0000000..bb6f0fa --- /dev/null +++ b/src/utils/tools/builtin/imagegen.rs @@ -0,0 +1,606 @@ +use async_trait::async_trait; +use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; +use reqwest::{Client, header::CONTENT_TYPE}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use std::time::Duration; + +use super::super::executor::{ToolContext, ToolError, ToolExecutor}; +use crate::utils::images::{image_url, store_from_data_uri}; + +const OPENAI_GENERATIONS_URL: &str = "https://api.openai.com/v1/images/generations"; +const OPENAI_EDITS_URL: &str = "https://api.openai.com/v1/images/edits"; +const REPLICATE_API_BASE: &str = "https://api.replicate.com/v1/models"; +const GOOGLE_GEMINI_GENERATIONS_URL: &str = "https://generativelanguage.googleapis.com/v1beta/openai/images/generations"; +const GOOGLE_GEMINI_CONTENT_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models"; + +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateImageInput { + pub prompt: String, + #[serde(default = "default_size")] + pub size: String, + #[serde(default = "default_quality")] + pub quality: String, +} + +fn default_size() -> String { + "1024x1024".to_string() +} + +fn default_quality() -> String { + "standard".to_string() +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct EditImageInput { + pub image_url: String, + pub prompt: String, +} + +#[derive(Debug, Serialize)] +pub struct ImageResult { + pub success: bool, + pub image_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Deserialize)] +struct OpenAIImageResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAIImageData { + #[serde(default)] + url: Option, + #[serde(default)] + b64_json: Option, +} + +#[derive(Debug, Deserialize)] +struct ReplicateResponse { + output: Option, + error: Option, +} + +#[derive(Debug, Deserialize)] +struct GeminiContentResponse { + candidates: Option>, +} + +#[derive(Debug, Deserialize)] +struct GeminiCandidate { + content: Option, + #[serde(rename = "finishReason")] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct GeminiContent { + parts: Vec, +} + +#[derive(Debug, Deserialize)] +struct GeminiPart { + #[serde(default)] + text: Option, + #[serde(rename = "inlineData")] + inline_data: Option, +} + +#[derive(Debug, Deserialize)] +struct GeminiInlineData { + #[serde(rename = "mimeType")] + mime_type: String, + data: String, +} + +pub struct ImageGenExecutor { + client: Client, +} + +impl ImageGenExecutor { + pub fn new() -> Result { + let client = Client::builder() + .timeout(Duration::from_secs(120)) + .build() + .map_err(|e| ToolError::Internal(format!("Failed to create HTTP client: {e}")))?; + + Ok(Self { client }) + } + + /// Process an image URL, uploading data URIs to CDN if database is available + async fn process_image_url(&self, url: String, ctx: &ToolContext) -> String { + if !url.starts_with("data:") { + return url; + } + + let Some(db) = &ctx.db else { + return url; + }; + + match store_from_data_uri(db, &url, ctx.user_id, Some("imagegen")).await { + Ok(stored) => image_url(stored.id), + Err(e) => { + eprintln!("[IMAGEGEN] Failed to upload image to CDN: {}", e); + url // Return original on failure + } + } + } + + async fn generate_openai(&self, input: &GenerateImageInput, api_key: &str, model: &str) -> Result { + let payload = json!({ + "model": model, + "prompt": input.prompt, + "n": 1, + "moderation": "low", + }); + + let response = self + .client + .post(OPENAI_GENERATIONS_URL) + .header("Authorization", format!("Bearer {api_key}")) + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await + .map_err(|e| ToolError::HttpError(format!("OpenAI API request failed: {e}")))?; + + let status = response.status(); + if !status.is_success() { + let error_body = response.text().await.unwrap_or_default(); + return Err(ToolError::HttpError(format!("OpenAI API error {status}: {error_body}"))); + } + + let openai_response: OpenAIImageResponse = response + .json() + .await + .map_err(|e| ToolError::ExecutionFailed(format!("Failed to parse OpenAI response: {e}")))?; + + let image_url = openai_response + .data + .first() + .and_then(|d| d.url.clone().or_else(|| d.b64_json.clone().map(|b| format!("data:image/png;base64,{b}")))) + .ok_or_else(|| ToolError::ExecutionFailed("No image URL in response".to_string()))?; + + Ok(json!(ImageResult { + success: true, + image_url: Some(image_url), + error: None, + })) + } + + async fn edit_openai(&self, input: &EditImageInput, api_key: &str, model: &str, ctx: &ToolContext) -> Result { + let (image_bytes, mime_type) = self.download_image(&input.image_url, ctx).await?; + let file_name = match mime_type.as_str() { + "image/png" => "image.png", + "image/jpeg" => "image.jpg", + "image/webp" => "image.webp", + "image/gif" => "image.gif", + _ => "image", + }; + let mime_type = if reqwest::multipart::Part::bytes(Vec::new()).mime_str(&mime_type).is_ok() { + mime_type + } else { + "application/octet-stream".to_string() + }; + + let form = reqwest::multipart::Form::new() + .part( + "image", + reqwest::multipart::Part::bytes(image_bytes).file_name(file_name).mime_str(&mime_type).unwrap(), + ) + .text("prompt", input.prompt.clone()) + .text("model", model.to_string()) + .text("n", "1"); + + let response = self + .client + .post(OPENAI_EDITS_URL) + .header("Authorization", format!("Bearer {api_key}")) + .multipart(form) + .send() + .await + .map_err(|e| ToolError::HttpError(format!("OpenAI API request failed: {e}")))?; + + let status = response.status(); + if !status.is_success() { + let error_body = response.text().await.unwrap_or_default(); + return Err(ToolError::HttpError(format!("OpenAI API error {status}: {error_body}"))); + } + + let openai_response: OpenAIImageResponse = response + .json() + .await + .map_err(|e| ToolError::ExecutionFailed(format!("Failed to parse OpenAI response: {e}")))?; + + let image_url = openai_response + .data + .first() + .and_then(|d| d.url.clone().or_else(|| d.b64_json.clone().map(|b| format!("data:image/png;base64,{b}")))) + .ok_or_else(|| ToolError::ExecutionFailed("No image URL in response".to_string()))?; + + Ok(json!(ImageResult { + success: true, + image_url: Some(image_url), + error: None, + })) + } + + async fn generate_replicate(&self, input: &GenerateImageInput, api_key: &str, model: &str) -> Result { + let url = format!("{REPLICATE_API_BASE}/{model}/predictions"); + + let payload = json!({ + "input": { + "prompt": input.prompt + } + }); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {api_key}")) + .header("Content-Type", "application/json") + .header("Prefer", "wait=60") + .json(&payload) + .send() + .await + .map_err(|e| ToolError::HttpError(format!("Replicate API request failed: {e}")))?; + + let status = response.status(); + if !status.is_success() { + let error_body = response.text().await.unwrap_or_default(); + return Err(ToolError::HttpError(format!("Replicate API error {status}: {error_body}"))); + } + + let replicate_response: ReplicateResponse = response + .json() + .await + .map_err(|e| ToolError::ExecutionFailed(format!("Failed to parse Replicate response: {e}")))?; + + if let Some(error) = replicate_response.error { + return Err(ToolError::ExecutionFailed(error)); + } + + let image_url = match replicate_response.output { + Some(Value::Array(arr)) => arr.first().and_then(|v| v.as_str()).map(String::from), + Some(Value::String(s)) => Some(s), + _ => None, + }; + + let image_url = image_url.ok_or_else(|| ToolError::ExecutionFailed("No image URL in Replicate response".to_string()))?; + + Ok(json!(ImageResult { + success: true, + image_url: Some(image_url), + error: None, + })) + } + + async fn edit_replicate(&self, input: &EditImageInput, api_key: &str, model: &str, ctx: &ToolContext) -> Result { + let url = format!("{REPLICATE_API_BASE}/{model}/predictions"); + + let (image_bytes, mime_type) = self.download_image(&input.image_url, ctx).await?; + let image_b64 = BASE64.encode(&image_bytes); + let image_data_uri = format!("data:{mime_type};base64,{image_b64}"); + + let image_param = get_replicate_image_param(model); + + let payload = json!({ + "input": { + "prompt": input.prompt, + image_param: image_data_uri + } + }); + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {api_key}")) + .header("Content-Type", "application/json") + .header("Prefer", "wait=60") + .json(&payload) + .send() + .await + .map_err(|e| ToolError::HttpError(format!("Replicate API request failed: {e}")))?; + + let status = response.status(); + if !status.is_success() { + let error_body = response.text().await.unwrap_or_default(); + return Err(ToolError::HttpError(format!("Replicate API error {status}: {error_body}"))); + } + + let replicate_response: ReplicateResponse = response + .json() + .await + .map_err(|e| ToolError::ExecutionFailed(format!("Failed to parse Replicate response: {e}")))?; + + if let Some(error) = replicate_response.error { + return Err(ToolError::ExecutionFailed(error)); + } + + let image_url = match replicate_response.output { + Some(Value::Array(arr)) => arr.first().and_then(|v| v.as_str()).map(String::from), + Some(Value::String(s)) => Some(s), + _ => None, + }; + + let image_url = image_url.ok_or_else(|| ToolError::ExecutionFailed("No image URL in Replicate response".to_string()))?; + + Ok(json!(ImageResult { + success: true, + image_url: Some(image_url), + error: None, + })) + } + + async fn generate_google(&self, input: &GenerateImageInput, api_key: &str, model: &str) -> Result { + let payload = json!({ + "prompt": input.prompt, + "model": model, + "response_format": "b64_json", + "n": 1 + }); + + let url = GOOGLE_GEMINI_GENERATIONS_URL; + + let response = self + .client + .post(url) + .header("x-goog-api-key", api_key) + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await + .map_err(|e| ToolError::HttpError(format!("Google API request failed: {e}")))?; + + let status = response.status(); + if !status.is_success() { + let error_body = response.text().await.unwrap_or_default(); + return Err(ToolError::HttpError(format!("Google API error {status}: {error_body}"))); + } + + let google_response: OpenAIImageResponse = response + .json() + .await + .map_err(|e| ToolError::ExecutionFailed(format!("Failed to parse Google response: {e}")))?; + + let image_url = google_response + .data + .first() + .and_then(|d| d.url.clone().or_else(|| d.b64_json.clone().map(|b| format!("data:image/png;base64,{b}")))) + .ok_or_else(|| ToolError::ExecutionFailed("No image data in response".to_string()))?; + + Ok(json!(ImageResult { + success: true, + image_url: Some(image_url), + error: None, + })) + } + + async fn edit_google(&self, input: &EditImageInput, api_key: &str, model: &str, ctx: &ToolContext) -> Result { + let url = format!("{GOOGLE_GEMINI_CONTENT_URL}/{model}:generateContent"); + + let (image_bytes, mime_type) = self.download_image(&input.image_url, ctx).await?; + let image_b64 = BASE64.encode(&image_bytes); + + let payload = json!({ + "contents": [{ + "parts": [ + { "text": input.prompt }, + { + "inline_data": { + "mime_type": mime_type, + "data": image_b64 + } + } + ] + }] + }); + + let response = self + .client + .post(&url) + .header("x-goog-api-key", api_key) + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await + .map_err(|e| ToolError::HttpError(format!("Google API request failed: {e}")))?; + + let status = response.status(); + if !status.is_success() { + let error_body = response.text().await.unwrap_or_default(); + return Err(ToolError::HttpError(format!("Google API error {status}: {error_body}"))); + } + + let gemini_response: GeminiContentResponse = response + .json() + .await + .map_err(|e| ToolError::ExecutionFailed(format!("Failed to parse Google response: {e}")))?; + + let candidates = gemini_response + .candidates + .ok_or_else(|| ToolError::ExecutionFailed("No candidates in response".to_string()))?; + + let candidate = candidates + .first() + .ok_or_else(|| ToolError::ExecutionFailed("No candidates in response".to_string()))?; + + if candidate.content.is_none() { + let finish_reason = candidate.finish_reason.as_deref().unwrap_or("UNKNOWN"); + return Err(ToolError::ExecutionFailed(match finish_reason { + "SAFETY" => "Image generation blocked by safety filters".to_string(), + "RECITATION" => "Image generation blocked due to potential copyright issues".to_string(), + "PROHIBITED_CONTENT" => "Image generation blocked due to prohibited content".to_string(), + _ => format!("Image generation failed: {finish_reason}"), + })); + } + + let content = candidate.content.as_ref().unwrap(); + let image_data = content + .parts + .iter() + .find_map(|p| p.inline_data.as_ref()) + .ok_or_else(|| ToolError::ExecutionFailed("No image data in response".to_string()))?; + + let image_url = format!("data:{};base64,{}", image_data.mime_type, image_data.data); + + Ok(json!(ImageResult { + success: true, + image_url: Some(image_url), + error: None, + })) + } + + async fn download_image(&self, url: &str, ctx: &ToolContext) -> Result<(Vec, String), ToolError> { + if url.starts_with("data:") { + let mut parts = url.splitn(2, ','); + let meta = parts.next().unwrap_or_default(); + let data_part = parts.next().ok_or_else(|| ToolError::InvalidInput("Invalid data URL format".to_string()))?; + let mime_type = meta + .strip_prefix("data:") + .and_then(|value| value.split(';').next()) + .filter(|value| !value.is_empty()) + .unwrap_or("application/octet-stream") + .to_string(); + let bytes = BASE64 + .decode(data_part) + .map_err(|e| ToolError::InvalidInput(format!("Failed to decode base64: {e}")))?; + return Ok((bytes, mime_type)); + } + + if let Some(id_str) = url.strip_prefix("/api/v1/images/") { + let db = ctx.db.as_ref().ok_or_else(|| ToolError::Internal("Database not available".to_string()))?; + let id = uuid::Uuid::parse_str(id_str).map_err(|e| ToolError::InvalidInput(format!("Invalid image ID: {e}")))?; + let (data, mime) = crate::utils::images::get_image(db, id) + .await + .map_err(|e| ToolError::Internal(format!("Failed to fetch image: {e}")))? + .ok_or_else(|| ToolError::InvalidInput("Image not found".to_string()))?; + let mime_type = if mime.is_empty() { "application/octet-stream".to_string() } else { mime }; + return Ok((data, mime_type)); + } + + let response = self + .client + .get(url) + .send() + .await + .map_err(|e| ToolError::HttpError(format!("Failed to download image: {e}")))?; + + let status = response.status(); + if !status.is_success() { + return Err(ToolError::HttpError(format!("Failed to download image: HTTP {status}"))); + } + + let header_mime = response + .headers() + .get(CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.split(';').next()) + .map(|value| value.to_string()); + + let bytes = response + .bytes() + .await + .map(|b| b.to_vec()) + .map_err(|e| ToolError::HttpError(format!("Failed to read image bytes: {e}")))?; + + let detected_mime = infer::get(&bytes).map(|kind| kind.mime_type().to_string()); + let mime_type = header_mime.or(detected_mime).unwrap_or_else(|| "application/octet-stream".to_string()); + + Ok((bytes, mime_type)) + } +} + +fn get_replicate_image_param(model: &str) -> &'static str { + let model_lower = model.to_lowercase(); + if model_lower.contains("nano-banana") { + "image_input" + } else if model_lower.contains("flux-redux") { + "redux_image" + } else if model_lower.contains("flux-kontext") { + "input_image" + } else { + "image" + } +} + +#[async_trait] +impl ToolExecutor for ImageGenExecutor { + async fn execute(&self, input: Value, ctx: &ToolContext) -> Result { + let function = match ctx.function_name.as_deref() { + Some("generate") | Some("imagegen") | None => "generate", + Some("edit") => "edit", + Some(other) => return Err(ToolError::ExecutionFailed(format!("Unknown function: {other}"))), + }; + + let api_key = ctx + .settings + .get("api_key") + .and_then(Value::as_str) + .ok_or_else(|| ToolError::MissingSetting("api_key".to_string()))?; + + let provider = ctx.settings.get("provider").and_then(Value::as_str).unwrap_or("openai"); + + let model = ctx.settings.get("model").and_then(Value::as_str).unwrap_or_else(|| match provider { + "openai" => "dall-e-3", + "google" => "imagen-3.0-generate-002", + _ => "black-forest-labs/flux-schnell", + }); + + let result = match (function, provider) { + ("generate", "openai") => { + let gen_input: GenerateImageInput = serde_json::from_value(input).map_err(|e| ToolError::InvalidInput(format!("Invalid input: {e}")))?; + self.generate_openai(&gen_input, api_key, model).await? + } + ("generate", "replicate") => { + let gen_input: GenerateImageInput = serde_json::from_value(input).map_err(|e| ToolError::InvalidInput(format!("Invalid input: {e}")))?; + self.generate_replicate(&gen_input, api_key, model).await? + } + ("generate", "google") => { + let gen_input: GenerateImageInput = serde_json::from_value(input).map_err(|e| ToolError::InvalidInput(format!("Invalid input: {e}")))?; + self.generate_google(&gen_input, api_key, model).await? + } + ("edit", "openai") => { + let edit_input: EditImageInput = serde_json::from_value(input).map_err(|e| ToolError::InvalidInput(format!("Invalid input: {e}")))?; + self.edit_openai(&edit_input, api_key, model, ctx).await? + } + ("edit", "replicate") => { + let edit_input: EditImageInput = serde_json::from_value(input).map_err(|e| ToolError::InvalidInput(format!("Invalid input: {e}")))?; + self.edit_replicate(&edit_input, api_key, model, ctx).await? + } + ("edit", "google") => { + let edit_input: EditImageInput = serde_json::from_value(input).map_err(|e| ToolError::InvalidInput(format!("Invalid input: {e}")))?; + self.edit_google(&edit_input, api_key, model, ctx).await? + } + _ => { + let gen_input: GenerateImageInput = serde_json::from_value(input).map_err(|e| ToolError::InvalidInput(format!("Invalid input: {e}")))?; + self.generate_openai(&gen_input, api_key, model).await? + } + }; + + // Process image URLs in the result - upload data URIs to CDN + if let Some(image_url_val) = result.get("image_url").and_then(|v| v.as_str()) { + let processed_url = self.process_image_url(image_url_val.to_string(), ctx).await; + if let Some(mut result_obj) = result.as_object().cloned() { + result_obj.insert("image_url".to_string(), Value::String(processed_url.clone())); + result_obj.insert( + "message".to_string(), + Value::String(format!( + "Image was successfully generated and uploaded to the CDN, you can show the image by using the following markdown: ![]({}). No other action is needed.", + processed_url + )), + ); + return Ok(Value::Object(result_obj)); + } + } + + Ok(result) + } + + fn name(&self) -> &str { + "imagegen" + } +} diff --git a/src/utils/tools/builtin/mod.rs b/src/utils/tools/builtin/mod.rs index 4eb92da..f5d9cea 100644 --- a/src/utils/tools/builtin/mod.rs +++ b/src/utils/tools/builtin/mod.rs @@ -1,5 +1,6 @@ //! Built-in tool implementations organized by category +pub mod imagegen; pub mod websearch; use super::executor::{ToolError, ToolExecutor}; @@ -7,13 +8,14 @@ use super::executor::{ToolError, ToolExecutor}; /// Get a builtin executor by ID /// /// # Arguments -/// * `builtin_id` - The builtin tool identifier (e.g., "websearch") +/// * `builtin_id` - The builtin tool identifier (e.g., "websearch", "imagegen") /// /// # Errors /// Returns `ToolError::NotFound` if the builtin ID is unknown pub fn get_builtin_executor(builtin_id: &str) -> Result, ToolError> { match builtin_id { "websearch" => Ok(Box::new(websearch::WebsearchExecutor::new()?)), + "imagegen" => Ok(Box::new(imagegen::ImageGenExecutor::new()?)), _ => Err(ToolError::NotFound(format!("Unknown builtin tool: {builtin_id}"))), } } diff --git a/src/utils/tools/executor.rs b/src/utils/tools/executor.rs index 478bd42..7bf532a 100644 --- a/src/utils/tools/executor.rs +++ b/src/utils/tools/executor.rs @@ -1,5 +1,7 @@ use async_trait::async_trait; use serde_json::Value; +use sqlx::PgPool; +use std::sync::Arc; use thiserror::Error; use uuid::Uuid; @@ -36,12 +38,26 @@ pub enum ToolError { Internal(String), } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ToolContext { pub user_id: Option, pub settings: Value, pub timeout_ms: Option, pub function_name: Option, + /// Database pool for tools that need storage access (e.g., imagegen) + pub db: Option>, +} + +impl std::fmt::Debug for ToolContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ToolContext") + .field("user_id", &self.user_id) + .field("settings", &self.settings) + .field("timeout_ms", &self.timeout_ms) + .field("function_name", &self.function_name) + .field("db", &self.db.as_ref().map(|_| "")) + .finish() + } } impl Default for ToolContext { @@ -51,6 +67,7 @@ impl Default for ToolContext { settings: Value::Object(serde_json::Map::new()), timeout_ms: Some(30000), function_name: None, + db: None, } } }