diff --git a/src/globals.ts b/src/globals.ts index 4d6e327e4..d827d8cdd 100644 --- a/src/globals.ts +++ b/src/globals.ts @@ -63,6 +63,7 @@ export const AI21: string = 'ai21'; export const BEDROCK: string = 'bedrock'; export const GROQ: string = 'groq'; export const SEGMIND: string = 'segmind'; +export const MODELSLAB: string = 'modelslab'; export const JINA: string = 'jina'; export const FIREWORKS_AI: string = 'fireworks-ai'; export const WORKERS_AI: string = 'workers-ai'; @@ -136,6 +137,7 @@ export const VALID_PROVIDERS = [ BEDROCK, GROQ, SEGMIND, + MODELSLAB, JINA, FIREWORKS_AI, WORKERS_AI, diff --git a/src/providers/index.ts b/src/providers/index.ts index 2cd5355f8..ac858f5fb 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -20,6 +20,7 @@ import OllamaAPIConfig from './ollama'; import { ProviderConfigs } from './types'; import GroqConfig from './groq'; import SegmindConfig from './segmind'; +import ModelsLabConfig from './modelslab'; import JinaConfig from './jina'; import FireworksAIConfig from './fireworks-ai'; import WorkersAiConfig from './workers-ai'; @@ -97,6 +98,7 @@ const Providers: { [key: string]: ProviderConfigs } = { bedrock: BedrockConfig, groq: GroqConfig, segmind: SegmindConfig, + modelslab: ModelsLabConfig, jina: JinaConfig, 'fireworks-ai': FireworksAIConfig, 'workers-ai': WorkersAiConfig, diff --git a/src/providers/modelslab/api.ts b/src/providers/modelslab/api.ts new file mode 100644 index 000000000..3991922db --- /dev/null +++ b/src/providers/modelslab/api.ts @@ -0,0 +1,18 @@ +import { ProviderAPIConfig } from '../types'; + +const ModelsLabAPIConfig: ProviderAPIConfig = { + getBaseURL: () => 'https://modelslab.com/api/v6', + headers: () => { + return { 'Content-Type': 'application/json' }; + }, + getEndpoint: ({ fn }) => { + switch (fn) { + case 'imageGenerate': + return '/images/text2img'; + default: + return ''; + } + }, +}; + +export default ModelsLabAPIConfig; diff --git a/src/providers/modelslab/imageGenerate.ts b/src/providers/modelslab/imageGenerate.ts new file mode 100644 index 000000000..1d076526b --- /dev/null +++ b/src/providers/modelslab/imageGenerate.ts @@ -0,0 +1,151 @@ +import { MODELSLAB } from '../../globals'; +import { Options } from '../../types/requestBody'; +import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types'; +import { + generateErrorResponse, + generateInvalidProviderResponseError, +} from '../utils'; + +export const ModelsLabImageGenerateConfig: ProviderConfig = { + // Inject the API key into the request body (ModelsLab requires "key" in JSON body) + _key: { + param: 'key', + required: true, + default: (_params: any, providerOptions: Options) => providerOptions.apiKey, + }, + prompt: { + param: 'prompt', + required: true, + }, + model: { + param: 'model_id', + default: 'flux', + }, + n: { + param: 'samples', + default: 1, + min: 1, + max: 4, + }, + size: [ + { + param: 'width', + transform: (params: any) => { + if (!params.size) return 512; + return parseInt(params.size.toLowerCase().split('x')[0]); + }, + min: 256, + }, + { + param: 'height', + transform: (params: any) => { + if (!params.size) return 512; + return parseInt(params.size.toLowerCase().split('x')[1]); + }, + min: 256, + }, + ], + steps: { + param: 'num_inference_steps', + default: 30, + min: 1, + max: 50, + }, + guidance_scale: { + param: 'guidance_scale', + default: 7.5, + min: 1, + max: 20, + }, + seed: { + param: 'seed', + }, + negative_prompt: { + param: 'negative_prompt', + }, + safety_checker: { + param: 'safety_checker', + default: 'no', + }, + webhook: { + param: 'webhook', + }, + track_id: { + param: 'track_id', + }, +}; + +interface ModelsLabImageGenerateSuccessResponse { + status: 'success'; + generationTime: number; + id: number; + output: string[]; + nsfw_content_detected: string[] | null; + meta: Record; +} + +interface ModelsLabImageGenerateProcessingResponse { + status: 'processing'; + id: number; + output: string[] | null; + fetch_result: string; + eta: number; + message: string; + messege?: string; +} + +interface ModelsLabImageGenerateErrorResponse { + status: 'error'; + message: string; + messege?: string; // Legacy typo in older API versions +} + +type ModelsLabImageGenerateResponse = + | ModelsLabImageGenerateSuccessResponse + | ModelsLabImageGenerateProcessingResponse + | ModelsLabImageGenerateErrorResponse; + +export const ModelsLabImageGenerateResponseTransform: ( + response: ModelsLabImageGenerateResponse, + responseStatus: number +) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 || response.status === 'error') { + const message = + ('message' in response && response.message) || + ('messege' in response && response.messege) || + 'Unknown error occurred'; + return generateErrorResponse( + { + message: message as string, + type: 'ModelsLabError', + param: null, + code: String(responseStatus), + }, + MODELSLAB + ); + } + + if (response.status === 'processing') { + // Generation is queued; return a processing response with the fetch URL. + // Consumers should use webhooks or poll `response.fetch_result` for the result. + return generateErrorResponse( + { + message: `Image generation is processing. Poll fetch URL: ${response.fetch_result} | ETA: ${response.eta}s`, + type: 'ModelsLabProcessing', + param: null, + code: '202', + }, + MODELSLAB + ); + } + + if (response.status === 'success' && response.output?.length) { + return { + created: Math.floor(Date.now() / 1000), + data: response.output.map((url) => ({ url })), + provider: MODELSLAB, + } as ImageGenerateResponse; + } + + return generateInvalidProviderResponseError(response, MODELSLAB); +}; diff --git a/src/providers/modelslab/index.ts b/src/providers/modelslab/index.ts new file mode 100644 index 000000000..52bf01495 --- /dev/null +++ b/src/providers/modelslab/index.ts @@ -0,0 +1,16 @@ +import { ProviderConfigs } from '../types'; +import ModelsLabAPIConfig from './api'; +import { + ModelsLabImageGenerateConfig, + ModelsLabImageGenerateResponseTransform, +} from './imageGenerate'; + +const ModelsLabConfig: ProviderConfigs = { + api: ModelsLabAPIConfig, + imageGenerate: ModelsLabImageGenerateConfig, + responseTransforms: { + imageGenerate: ModelsLabImageGenerateResponseTransform, + }, +}; + +export default ModelsLabConfig;