6 changes: 5 additions & 1 deletion src/providers/ai21/api.ts
@@ -12,10 +12,14 @@ const AI21APIConfig: ProviderAPIConfig = {
const { model } = gatewayRequestBodyJSON;
switch (fn) {
case 'complete': {
// Legacy Jurassic-2 models use the model-specific completion endpoint
return `/${model}/complete`;
}
case 'chatComplete': {
return `/${model}/chat`;
// Jamba models (jamba-1.5-*, jamba-1.6-*, jamba-instruct, etc.)
// use the OpenAI-compatible /chat/completions endpoint.
// Reference: https://docs.ai21.com/reference/jamba-1-6-api-ref
return `/chat/completions`;
}
case 'embed': {
return `/embed`;
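Reviewer note: the hunk above is easier to reason about as a standalone function. The sketch below restates the routing logic from the diff; it is illustrative, not the actual `AI21APIConfig` export, and the model names in the usage comments are examples only.

```ts
type AI21Fn = 'complete' | 'chatComplete' | 'embed';

const endpointFor = (fn: AI21Fn, model: string): string => {
  switch (fn) {
    case 'complete':
      // Legacy Jurassic-2: model-specific completion path
      return `/${model}/complete`;
    case 'chatComplete':
      // Jamba: one shared OpenAI-compatible path; the model travels in the request body
      return '/chat/completions';
    case 'embed':
      return '/embed';
  }
};

// endpointFor('chatComplete', 'jamba-1.6-large') -> '/chat/completions'
// endpointFor('complete', 'j2-ultra')            -> '/j2-ultra/complete'
```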
154 changes: 2 additions & 152 deletions src/providers/ai21/chatComplete.ts
@@ -1,125 +1,8 @@
import { AI21 } from '../../globals';
import { Params, SYSTEM_MESSAGE_ROLES } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
ProviderConfig,
} from '../types';
import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';
import { ErrorResponse } from '../types';
import { generateErrorResponse } from '../utils';
import { AI21ErrorResponse } from './complete';

export const AI21ChatCompleteConfig: ProviderConfig = {
messages: [
{
param: 'messages',
required: true,
transform: (params: Params) => {
let inputMessages: any = [];

if (
params.messages?.[0]?.role &&
SYSTEM_MESSAGE_ROLES.includes(params.messages?.[0]?.role)
) {
inputMessages = params.messages.slice(1);
} else if (params.messages) {
inputMessages = params.messages;
}

return inputMessages.map((msg: any) => ({
text: msg.content,
role: msg.role,
}));
},
},
{
param: 'system',
required: false,
transform: (params: Params) => {
if (
params.messages?.[0]?.role &&
SYSTEM_MESSAGE_ROLES.includes(params.messages?.[0]?.role)
) {
return params.messages?.[0].content;
}
},
},
],
n: {
param: 'numResults',
default: 1,
},
max_tokens: {
param: 'maxTokens',
default: 16,
},
max_completion_tokens: {
param: 'maxTokens',
default: 16,
},
minTokens: {
param: 'minTokens',
default: 0,
},
temperature: {
param: 'temperature',
default: 0.7,
min: 0,
max: 1,
},
top_p: {
param: 'topP',
default: 1,
},
top_k: {
param: 'topKReturn',
default: 0,
},
stop: {
param: 'stopSequences',
},
presence_penalty: {
param: 'presencePenalty',
transform: (params: Params) => {
return {
scale: params.presence_penalty,
};
},
},
frequency_penalty: {
param: 'frequencyPenalty',
transform: (params: Params) => {
return {
scale: params.frequency_penalty,
};
},
},
countPenalty: {
param: 'countPenalty',
},
frequencyPenalty: {
param: 'frequencyPenalty',
},
presencePenalty: {
param: 'presencePenalty',
},
};

interface AI21ChatCompleteResponse {
id: string;
outputs: {
text: string;
role: string;
finishReason: {
reason: string;
length: number | null;
sequence: string | null;
};
}[];
}

export const AI21ErrorResponseTransform: (
response: AI21ErrorResponse
) => ErrorResponse | undefined = (response) => {
@@ -132,36 +15,3 @@ export const AI21ErrorResponseTransform:

return undefined;
};

export const AI21ChatCompleteResponseTransform: (
response: AI21ChatCompleteResponse | AI21ErrorResponse,
responseStatus: number
) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200) {
const errorResposne = AI21ErrorResponseTransform(
response as AI21ErrorResponse
);
if (errorResposne) return errorResposne;
}

if ('outputs' in response) {
return {
id: response.id,
object: 'chat.completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: AI21,
choices: response.outputs.map((o, index) => ({
message: {
role: 'assistant',
content: o.text,
},
index: index,
logprobs: null,
finish_reason: o.finishReason?.reason,
})),
};
}

return generateInvalidProviderResponseError(response, AI21);
};
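Worth noting for review: the deleted `AI21ChatCompleteResponseTransform` existed only to map AI21's legacy `outputs` array onto OpenAI-style `choices`. Jamba's /chat/completions responses already arrive in the OpenAI shape, so the shared transformer can take over. A sketch of the two shapes follows; the field values are illustrative, not captured responses.

```ts
// Legacy chat shape the deleted transform consumed (see the removed
// AI21ChatCompleteResponse interface above):
const legacyResponse = {
  id: 'example-id',
  outputs: [
    {
      text: 'Hello!',
      role: 'assistant',
      finishReason: { reason: 'stop', length: null, sequence: null },
    },
  ],
};

// What Jamba's /chat/completions returns: already the OpenAI schema, so it
// can be passed through with only provider tagging:
const jambaResponse = {
  id: 'example-id',
  object: 'chat.completion',
  created: 1700000000,
  model: 'jamba-1.6-large',
  choices: [
    {
      index: 0,
      message: { role: 'assistant', content: 'Hello!' },
      logprobs: null,
      finish_reason: 'stop',
    },
  ],
};
```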
33 changes: 27 additions & 6 deletions src/providers/ai21/index.ts
@@ -1,20 +1,41 @@
import { AI21 } from '../../globals';
import { chatCompleteParams, responseTransformers } from '../open-ai-base';
import { ProviderConfigs } from '../types';
import AI21APIConfig from './api';
import {
AI21ChatCompleteConfig,
AI21ChatCompleteResponseTransform,
} from './chatComplete';
import { AI21CompleteConfig, AI21CompleteResponseTransform } from './complete';
import { AI21EmbedConfig, AI21EmbedResponseTransform } from './embed';

/**
* AI21 Studio provider configuration.
*
* Chat completions use the Jamba model family via the OpenAI-compatible
* /v1/chat/completions endpoint (introduced with Jamba 1.5 and used by all later Jamba models).
* Reference: https://docs.ai21.com/reference/jamba-1-6-api-ref
*
* Legacy Jurassic-2 text completions continue to use the model-specific
* endpoint: /v1/{model}/complete
*/
const AI21Config: ProviderConfigs = {
// Legacy Jurassic-2 text completion (unchanged)
complete: AI21CompleteConfig,
chatComplete: AI21ChatCompleteConfig,

// Jamba chat completions via OpenAI-compatible /v1/chat/completions endpoint.
// Excludes OpenAI params not supported by AI21 Jamba.
chatComplete: chatCompleteParams([
'logit_bias',
'logprobs',
'top_logprobs',
'service_tier',
'parallel_tool_calls',
]),

embed: AI21EmbedConfig,
api: AI21APIConfig,
responseTransforms: {
// Spread the OpenAI-compatible transformers (sets chatComplete).
// Then override complete and embed with AI21-specific transforms.
...responseTransformers(AI21, { chatComplete: true }),
complete: AI21CompleteResponseTransform,
chatComplete: AI21ChatCompleteResponseTransform,
embed: AI21EmbedResponseTransform,
},
};
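The two helpers imported from `open-ai-base` do the heavy lifting in this file. Below is a hedged sketch of how they compose, based only on their usage in this diff; the behaviors described in the comments are assumptions about the helpers, not the actual `src/providers/open-ai-base` source.

```ts
import { AI21 } from '../../globals';
import { chatCompleteParams, responseTransformers } from '../open-ai-base';
import { AI21CompleteResponseTransform } from './complete';
import { AI21EmbedResponseTransform } from './embed';

// chatCompleteParams(excluded) is read here as "the OpenAI chat param map
// minus these keys", so unsupported params never reach AI21:
const chatConfig = chatCompleteParams(['logit_bias', 'logprobs']);

// responseTransformers(AI21, { chatComplete: true }) is read as returning
// { chatComplete: <OpenAI-normalizing transform> }; spreading it first lets
// the AI21-specific complete/embed transforms override only their own keys:
const responseTransforms = {
  ...responseTransformers(AI21, { chatComplete: true }),
  complete: AI21CompleteResponseTransform,
  embed: AI21EmbedResponseTransform,
};
```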
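Finally, a hypothetical end-to-end check through a locally running gateway. The port, header names, and model are placeholders to verify against the repo's dev setup, not facts taken from this PR.

```ts
// Hypothetical smoke test; adjust the port, headers, and API key to your
// local gateway configuration.
const res = await fetch('http://localhost:8787/v1/chat/completions', {
  method: 'POST',
  headers: {
    'content-type': 'application/json',
    'x-portkey-provider': 'ai21',
    authorization: 'Bearer YOUR_AI21_API_KEY',
  },
  body: JSON.stringify({
    model: 'jamba-1.6-large',
    max_tokens: 64,
    messages: [{ role: 'user', content: 'Say hello.' }],
  }),
});

// With this change the request is routed to AI21's /chat/completions and
// the JSON comes back OpenAI-shaped, with no provider-specific chat transform.
console.log(await res.json());
```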