diff --git a/src/providers/ai21/api.ts b/src/providers/ai21/api.ts index 6e1f94f8a..c24147982 100644 --- a/src/providers/ai21/api.ts +++ b/src/providers/ai21/api.ts @@ -12,10 +12,14 @@ const AI21APIConfig: ProviderAPIConfig = { const { model } = gatewayRequestBodyJSON; switch (fn) { case 'complete': { + // Legacy Jurassic-2 models use the model-specific completion endpoint return `/${model}/complete`; } case 'chatComplete': { - return `/${model}/chat`; + // Jamba models (jamba-1.5-*, jamba-1.6-*, jamba-instruct, etc.) + // use the OpenAI-compatible /chat/completions endpoint. + // Reference: https://docs.ai21.com/reference/jamba-1-6-api-ref + return `/chat/completions`; } case 'embed': { return `/embed`; diff --git a/src/providers/ai21/chatComplete.ts b/src/providers/ai21/chatComplete.ts index 304741086..eb5c9d9a6 100644 --- a/src/providers/ai21/chatComplete.ts +++ b/src/providers/ai21/chatComplete.ts @@ -1,125 +1,8 @@ import { AI21 } from '../../globals'; -import { Params, SYSTEM_MESSAGE_ROLES } from '../../types/requestBody'; -import { - ChatCompletionResponse, - ErrorResponse, - ProviderConfig, -} from '../types'; -import { - generateErrorResponse, - generateInvalidProviderResponseError, -} from '../utils'; +import { ErrorResponse } from '../types'; +import { generateErrorResponse } from '../utils'; import { AI21ErrorResponse } from './complete'; -export const AI21ChatCompleteConfig: ProviderConfig = { - messages: [ - { - param: 'messages', - required: true, - transform: (params: Params) => { - let inputMessages: any = []; - - if ( - params.messages?.[0]?.role && - SYSTEM_MESSAGE_ROLES.includes(params.messages?.[0]?.role) - ) { - inputMessages = params.messages.slice(1); - } else if (params.messages) { - inputMessages = params.messages; - } - - return inputMessages.map((msg: any) => ({ - text: msg.content, - role: msg.role, - })); - }, - }, - { - param: 'system', - required: false, - transform: (params: Params) => { - if ( - params.messages?.[0]?.role && - 
SYSTEM_MESSAGE_ROLES.includes(params.messages?.[0]?.role) - ) { - return params.messages?.[0].content; - } - }, - }, - ], - n: { - param: 'numResults', - default: 1, - }, - max_tokens: { - param: 'maxTokens', - default: 16, - }, - max_completion_tokens: { - param: 'maxTokens', - default: 16, - }, - minTokens: { - param: 'minTokens', - default: 0, - }, - temperature: { - param: 'temperature', - default: 0.7, - min: 0, - max: 1, - }, - top_p: { - param: 'topP', - default: 1, - }, - top_k: { - param: 'topKReturn', - default: 0, - }, - stop: { - param: 'stopSequences', - }, - presence_penalty: { - param: 'presencePenalty', - transform: (params: Params) => { - return { - scale: params.presence_penalty, - }; - }, - }, - frequency_penalty: { - param: 'frequencyPenalty', - transform: (params: Params) => { - return { - scale: params.frequency_penalty, - }; - }, - }, - countPenalty: { - param: 'countPenalty', - }, - frequencyPenalty: { - param: 'frequencyPenalty', - }, - presencePenalty: { - param: 'presencePenalty', - }, -}; - -interface AI21ChatCompleteResponse { - id: string; - outputs: { - text: string; - role: string; - finishReason: { - reason: string; - length: number | null; - sequence: string | null; - }; - }[]; -} - export const AI21ErrorResponseTransform: ( response: AI21ErrorResponse ) => ErrorResponse | undefined = (response) => { @@ -132,36 +15,3 @@ export const AI21ErrorResponseTransform: ( return undefined; }; - -export const AI21ChatCompleteResponseTransform: ( - response: AI21ChatCompleteResponse | AI21ErrorResponse, - responseStatus: number -) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { - if (responseStatus !== 200) { - const errorResposne = AI21ErrorResponseTransform( - response as AI21ErrorResponse - ); - if (errorResposne) return errorResposne; - } - - if ('outputs' in response) { - return { - id: response.id, - object: 'chat.completion', - created: Math.floor(Date.now() / 1000), - model: '', - provider: AI21, - choices: 
response.outputs.map((o, index) => ({ - message: { - role: 'assistant', - content: o.text, - }, - index: index, - logprobs: null, - finish_reason: o.finishReason?.reason, - })), - }; - } - - return generateInvalidProviderResponseError(response, AI21); -}; diff --git a/src/providers/ai21/index.ts b/src/providers/ai21/index.ts index dc157a83c..e580e5c69 100644 --- a/src/providers/ai21/index.ts +++ b/src/providers/ai21/index.ts @@ -1,20 +1,41 @@ +import { AI21 } from '../../globals'; +import { chatCompleteParams, responseTransformers } from '../open-ai-base'; import { ProviderConfigs } from '../types'; import AI21APIConfig from './api'; -import { - AI21ChatCompleteConfig, - AI21ChatCompleteResponseTransform, -} from './chatComplete'; import { AI21CompleteConfig, AI21CompleteResponseTransform } from './complete'; import { AI21EmbedConfig, AI21EmbedResponseTransform } from './embed'; +/** + * AI21 Studio provider configuration. + * + * Chat completions use the Jamba model family via the OpenAI-compatible + * /v1/chat/completions endpoint (introduced with Jamba 1.5 and later). + * Reference: https://docs.ai21.com/reference/jamba-1-6-api-ref + * + * Legacy Jurassic-2 text completions continue to use the model-specific + * endpoint: /v1/{model}/complete + */ const AI21Config: ProviderConfigs = { + // Legacy Jurassic-2 text completion (unchanged) complete: AI21CompleteConfig, - chatComplete: AI21ChatCompleteConfig, + + // Jamba chat completions via OpenAI-compatible /v1/chat/completions endpoint. + // Excludes OpenAI params not supported by AI21 Jamba. + chatComplete: chatCompleteParams([ + 'logit_bias', + 'logprobs', + 'top_logprobs', + 'service_tier', + 'parallel_tool_calls', + ]), + + embed: AI21EmbedConfig, + api: AI21APIConfig, responseTransforms: { + // Spread the OpenAI-compatible transformers (sets chatComplete only). + // The complete and embed entries below are AI21-specific transforms added + // alongside the spread (it does not set them, so nothing is overridden). 
+ ...responseTransformers(AI21, { chatComplete: true }), complete: AI21CompleteResponseTransform, - chatComplete: AI21ChatCompleteResponseTransform, embed: AI21EmbedResponseTransform, }, };