diff --git a/src/providers/google-vertex-ai/api.ts b/src/providers/google-vertex-ai/api.ts index a714319ef..30eb4fd71 100644 --- a/src/providers/google-vertex-ai/api.ts +++ b/src/providers/google-vertex-ai/api.ts @@ -179,6 +179,10 @@ export const GoogleApiConfig: ProviderAPIConfig = { 'imageGenerate', `${projectRoute}/publishers/${provider}/models/${model}:predict`, ], + [ + 'messagesCountTokens', + `${projectRoute}/publishers/${provider}/models/${model}:countTokens`, + ], ]); switch (provider) { diff --git a/src/providers/google-vertex-ai/geminiCountTokens.ts b/src/providers/google-vertex-ai/geminiCountTokens.ts new file mode 100644 index 000000000..d4c06157d --- /dev/null +++ b/src/providers/google-vertex-ai/geminiCountTokens.ts @@ -0,0 +1,55 @@ +import { GOOGLE_VERTEX_AI } from '../../globals'; +import { ErrorResponse, ProviderConfig } from '../types'; +import { + generateErrorResponse, + generateInvalidProviderResponseError, +} from '../utils'; +import { VertexGoogleChatCompleteConfig } from './chatComplete'; +import { GoogleErrorResponse } from './types'; + +// VertexGeminiCountTokensConfig reuses the same transforms as +// VertexGoogleChatCompleteConfig for contents, systemInstruction, tools, and +// tool_choice so that all token-contributing parameters are forwarded to the +// Vertex AI Gemini countTokens endpoint. +export const VertexGeminiCountTokensConfig: ProviderConfig = { + model: { + param: 'model', + required: true, + }, + messages: VertexGoogleChatCompleteConfig.messages, + tools: VertexGoogleChatCompleteConfig.tools, + tool_choice: VertexGoogleChatCompleteConfig.tool_choice, +}; + +interface VertexGeminiCountTokensResponse { + totalTokens: number; + cachedContentTokenCount?: number; +} + +// VertexGeminiCountTokensResponseTransform maps Vertex AI Gemini's +// { totalTokens } to the gateway's unified { input_tokens } format. +export const VertexGeminiCountTokensResponseTransform: ( + response: VertexGeminiCountTokensResponse | GoogleErrorResponse, + responseStatus: number +) => { input_tokens: number } | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && 'error' in response) { + const errorResponse = response as GoogleErrorResponse; + return generateErrorResponse( + { + message: errorResponse.error?.message ?? '', + type: errorResponse.error?.status ?? null, + param: null, + code: String(errorResponse.error?.code ?? ''), + }, + GOOGLE_VERTEX_AI + ); + } + + if ('totalTokens' in response) { + return { + input_tokens: response.totalTokens, + }; + } + + return generateInvalidProviderResponseError(response, GOOGLE_VERTEX_AI); +}; diff --git a/src/providers/google-vertex-ai/index.ts b/src/providers/google-vertex-ai/index.ts index 45199c52a..e6b7b00cc 100644 --- a/src/providers/google-vertex-ai/index.ts +++ b/src/providers/google-vertex-ai/index.ts @@ -53,6 +53,10 @@ import { VertexAnthropicMessagesResponseTransform, } from './messages'; import { VertexAnthropicMessagesCountTokensConfig } from './messagesCountTokens'; +import { + VertexGeminiCountTokensConfig, + VertexGeminiCountTokensResponseTransform, +} from './geminiCountTokens'; import { GetMistralAIChatCompleteResponseTransform, GetMistralAIChatCompleteStreamChunkTransform, @@ -112,6 +116,7 @@ const VertexConfig: ProviderConfigs = { api: GoogleApiConfig, embed: GoogleEmbedConfig, imageGenerate: GoogleImageGenConfig, + messagesCountTokens: VertexGeminiCountTokensConfig, createBatch: GoogleBatchCreateConfig, createFinetune: baseConfig.createFinetune, responseTransforms: { @@ -119,6 +124,7 @@ const VertexConfig: ProviderConfigs = { chatComplete: GoogleChatCompleteResponseTransform, embed: GoogleEmbedResponseTransform, imageGenerate: GoogleImageGenResponseTransform, + messagesCountTokens: VertexGeminiCountTokensResponseTransform, ...responseTransforms, }, requestTransforms: { diff --git a/src/providers/google/api.ts b/src/providers/google/api.ts index a842de5f7..9369bc89a 100644 --- a/src/providers/google/api.ts +++ b/src/providers/google/api.ts @@ -23,6 +23,9 @@ export const GoogleApiConfig: ProviderAPIConfig = { case 'embed': { return `/${routeVersion}/models/${model}:embedContent?key=${apiKey}`; } + case 'messagesCountTokens': { + return `/${routeVersion}/models/${model}:countTokens?key=${apiKey}`; + } default: return ''; } diff --git a/src/providers/google/countTokens.ts b/src/providers/google/countTokens.ts new file mode 100644 index 000000000..1442bb765 --- /dev/null +++ b/src/providers/google/countTokens.ts @@ -0,0 +1,49 @@ +import { GOOGLE } from '../../globals'; +import { ErrorResponse, ProviderConfig } from '../types'; +import { generateInvalidProviderResponseError } from '../utils'; +import { + GoogleChatCompleteConfig, + GoogleErrorResponse, + GoogleErrorResponseTransform, +} from './chatComplete'; + +// GoogleCountTokensConfig reuses the same transforms as GoogleChatCompleteConfig +// for contents, systemInstruction, tools, and tool_choice so that all token-contributing +// parameters are forwarded to the Gemini countTokens endpoint. +// The model param is required but does not need a default for counting. +export const GoogleCountTokensConfig: ProviderConfig = { + model: { + param: 'model', + required: true, + }, + messages: GoogleChatCompleteConfig.messages, + tools: GoogleChatCompleteConfig.tools, + tool_choice: GoogleChatCompleteConfig.tool_choice, +}; + +interface GoogleCountTokensResponse { + totalTokens: number; + cachedContentTokenCount?: number; +} + +// GoogleCountTokensResponseTransform maps Gemini's { totalTokens } to the +// gateway's unified { input_tokens } format. +export const GoogleCountTokensResponseTransform: ( + response: GoogleCountTokensResponse | GoogleErrorResponse, + responseStatus: number +) => { input_tokens: number } | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200) { + const errorResponse = GoogleErrorResponseTransform( + response as GoogleErrorResponse + ); + if (errorResponse) return errorResponse; + } + + if ('totalTokens' in response) { + return { + input_tokens: response.totalTokens, + }; + } + + return generateInvalidProviderResponseError(response, GOOGLE); +}; diff --git a/src/providers/google/index.ts b/src/providers/google/index.ts index a539f921e..6f2a83b25 100644 --- a/src/providers/google/index.ts +++ b/src/providers/google/index.ts @@ -5,16 +5,22 @@ import { GoogleChatCompleteResponseTransform, GoogleChatCompleteStreamChunkTransform, } from './chatComplete'; +import { + GoogleCountTokensConfig, + GoogleCountTokensResponseTransform, +} from './countTokens'; import { GoogleEmbedConfig, GoogleEmbedResponseTransform } from './embed'; const GoogleConfig: ProviderConfigs = { api: GoogleApiConfig, chatComplete: GoogleChatCompleteConfig, embed: GoogleEmbedConfig, + messagesCountTokens: GoogleCountTokensConfig, responseTransforms: { chatComplete: GoogleChatCompleteResponseTransform, 'stream-chatComplete': GoogleChatCompleteStreamChunkTransform, embed: GoogleEmbedResponseTransform, + messagesCountTokens: GoogleCountTokensResponseTransform, }, };