Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export const AI21: string = 'ai21';
export const BEDROCK: string = 'bedrock';
export const GROQ: string = 'groq';
export const SEGMIND: string = 'segmind';
export const MODELSLAB: string = 'modelslab';
export const JINA: string = 'jina';
export const FIREWORKS_AI: string = 'fireworks-ai';
export const WORKERS_AI: string = 'workers-ai';
Expand Down Expand Up @@ -136,6 +137,7 @@ export const VALID_PROVIDERS = [
BEDROCK,
GROQ,
SEGMIND,
MODELSLAB,
JINA,
FIREWORKS_AI,
WORKERS_AI,
Expand Down
2 changes: 2 additions & 0 deletions src/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import OllamaAPIConfig from './ollama';
import { ProviderConfigs } from './types';
import GroqConfig from './groq';
import SegmindConfig from './segmind';
import ModelsLabConfig from './modelslab';
import JinaConfig from './jina';
import FireworksAIConfig from './fireworks-ai';
import WorkersAiConfig from './workers-ai';
Expand Down Expand Up @@ -97,6 +98,7 @@ const Providers: { [key: string]: ProviderConfigs } = {
bedrock: BedrockConfig,
groq: GroqConfig,
segmind: SegmindConfig,
modelslab: ModelsLabConfig,
jina: JinaConfig,
'fireworks-ai': FireworksAIConfig,
'workers-ai': WorkersAiConfig,
Expand Down
18 changes: 18 additions & 0 deletions src/providers/modelslab/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { ProviderAPIConfig } from '../types';

const ModelsLabAPIConfig: ProviderAPIConfig = {
getBaseURL: () => 'https://modelslab.com/api/v6',
headers: () => {
return { 'Content-Type': 'application/json' };
},
getEndpoint: ({ fn }) => {
switch (fn) {
case 'imageGenerate':
return '/images/text2img';
default:
return '';
}
},
};

export default ModelsLabAPIConfig;
151 changes: 151 additions & 0 deletions src/providers/modelslab/imageGenerate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import { MODELSLAB } from '../../globals';
import { Options } from '../../types/requestBody';
import { ErrorResponse, ImageGenerateResponse, ProviderConfig } from '../types';
import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';

export const ModelsLabImageGenerateConfig: ProviderConfig = {
// Inject the API key into the request body (ModelsLab requires "key" in JSON body)
_key: {
param: 'key',
required: true,
default: (_params: any, providerOptions: Options) => providerOptions.apiKey,
},
prompt: {
param: 'prompt',
required: true,
},
model: {
param: 'model_id',
default: 'flux',
},
n: {
param: 'samples',
default: 1,
min: 1,
max: 4,
},
size: [
{
param: 'width',
transform: (params: any) => {
if (!params.size) return 512;
return parseInt(params.size.toLowerCase().split('x')[0]);
},
min: 256,
},
{
param: 'height',
transform: (params: any) => {
if (!params.size) return 512;
return parseInt(params.size.toLowerCase().split('x')[1]);
},
min: 256,
},
],
steps: {
param: 'num_inference_steps',
default: 30,
min: 1,
max: 50,
},
guidance_scale: {
param: 'guidance_scale',
default: 7.5,
min: 1,
max: 20,
},
seed: {
param: 'seed',
},
negative_prompt: {
param: 'negative_prompt',
},
safety_checker: {
param: 'safety_checker',
default: 'no',
},
webhook: {
param: 'webhook',
},
track_id: {
param: 'track_id',
},
};

interface ModelsLabImageGenerateSuccessResponse {
status: 'success';
generationTime: number;
id: number;
output: string[];
nsfw_content_detected: string[] | null;
meta: Record<string, any>;
}

interface ModelsLabImageGenerateProcessingResponse {
status: 'processing';
id: number;
output: string[] | null;
fetch_result: string;
eta: number;
message: string;
messege?: string;
}

interface ModelsLabImageGenerateErrorResponse {
status: 'error';
message: string;
messege?: string; // Legacy typo in older API versions
}

type ModelsLabImageGenerateResponse =
| ModelsLabImageGenerateSuccessResponse
| ModelsLabImageGenerateProcessingResponse
| ModelsLabImageGenerateErrorResponse;

export const ModelsLabImageGenerateResponseTransform: (
response: ModelsLabImageGenerateResponse,
responseStatus: number
) => ImageGenerateResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200 || response.status === 'error') {
const message =
('message' in response && response.message) ||
('messege' in response && response.messege) ||
'Unknown error occurred';
return generateErrorResponse(
{
message: message as string,
type: 'ModelsLabError',
param: null,
code: String(responseStatus),
},
MODELSLAB
);
}

if (response.status === 'processing') {
// Generation is queued; return a processing response with the fetch URL.
// Consumers should use webhooks or poll `response.fetch_result` for the result.
return generateErrorResponse(
{
message: `Image generation is processing. Poll fetch URL: ${response.fetch_result} | ETA: ${response.eta}s`,
type: 'ModelsLabProcessing',
param: null,
code: '202',
},
MODELSLAB
);
}

if (response.status === 'success' && response.output?.length) {
return {
created: Math.floor(Date.now() / 1000),
data: response.output.map((url) => ({ url })),
provider: MODELSLAB,
} as ImageGenerateResponse;
}

return generateInvalidProviderResponseError(response, MODELSLAB);
};
16 changes: 16 additions & 0 deletions src/providers/modelslab/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { ProviderConfigs } from '../types';
import ModelsLabAPIConfig from './api';
import {
ModelsLabImageGenerateConfig,
ModelsLabImageGenerateResponseTransform,
} from './imageGenerate';

const ModelsLabConfig: ProviderConfigs = {
api: ModelsLabAPIConfig,
imageGenerate: ModelsLabImageGenerateConfig,
responseTransforms: {
imageGenerate: ModelsLabImageGenerateResponseTransform,
},
};

export default ModelsLabConfig;