diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 0a47ae416a..08244ccd34 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -749,6 +749,9 @@ importers: '@ai-sdk/cerebras': specifier: ^1.0.0 version: 1.0.35(zod@3.25.76) + '@ai-sdk/deepinfra': + specifier: ^2.0.31 + version: 2.0.31(zod@3.25.76) '@ai-sdk/deepseek': specifier: ^2.0.14 version: 2.0.14(zod@3.25.76) @@ -1429,6 +1432,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/deepinfra@2.0.31': + resolution: {integrity: sha512-kKo5hX83yrPLO4gkzLc5hgF+iNNMVYDkyXXXkdbBsgc3x1Dttl3oRO9nDAVPNkxDauZ7t1tSxLELpGo7l5QD3g==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/deepseek@2.0.14': resolution: {integrity: sha512-1vXh8sVwRJYd1JO57qdy1rACucaNLDoBRCwOER3EbPgSF2vNVPcdJywGutA01Bhn7Cta+UJQ+k5y/yzMAIpP2w==} engines: {node: '>=18'} @@ -1501,6 +1510,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.27': + resolution: {integrity: sha512-YpAZe7OQuMkYqcM/m1BMX0xFn4QdhuL4qGo8sNaiLq1VjEeU/pPfz51rnlpCfCvYanUL5TjIZEbdclBUwLooSQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20': resolution: {integrity: sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==} engines: {node: '>=18'} @@ -11143,6 +11158,13 @@ snapshots: '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/deepinfra@2.0.31(zod@3.25.76)': + dependencies: + '@ai-sdk/openai-compatible': 2.0.27(zod@3.25.76) + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/deepseek@2.0.14(zod@3.25.76)': dependencies: '@ai-sdk/provider': 3.0.5 @@ -11222,6 +11244,12 @@ snapshots: '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.27(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 @@ -15261,7 +15289,7 @@ snapshots: sirv: 3.0.1 tinyglobby: 0.2.14 tinyrainbow: 2.0.0 - vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.2.1)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.17.50)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(tsx@4.19.4)(yaml@2.8.0) '@vitest/utils@3.2.4': dependencies: diff --git a/src/api/providers/__tests__/deepinfra.spec.ts b/src/api/providers/__tests__/deepinfra.spec.ts index c4a9275762..2b7cd337ab 100644 --- a/src/api/providers/__tests__/deepinfra.spec.ts +++ b/src/api/providers/__tests__/deepinfra.spec.ts @@ -1,386 +1,527 @@ // npx vitest api/providers/__tests__/deepinfra.spec.ts -import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types" - -const mockCreate = vitest.fn() -const mockWithResponse = vitest.fn() - -vitest.mock("openai", () => { - const mockConstructor = vitest.fn() +const { mockStreamText, mockGenerateText, mockCreateDeepInfra } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), + mockCreateDeepInfra: vi.fn(), +})) +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - __esModule: true, - default: mockConstructor.mockImplementation(() => ({ - chat: { - completions: { - create: mockCreate.mockImplementation(() => ({ - withResponse: mockWithResponse, - })), - }, - }, - })), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) 
-vitest.mock("../fetchers/modelCache", () => ({ - getModels: vitest.fn().mockResolvedValue({ +vi.mock("@ai-sdk/deepinfra", () => ({ + createDeepInfra: mockCreateDeepInfra.mockImplementation(() => { + return vi.fn(() => ({ + modelId: "test-model", + provider: "deepinfra", + })) + }), +})) + +vi.mock("../constants", () => ({ + DEFAULT_HEADERS: { + "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", + "X-Title": "Roo Code", + "User-Agent": "RooCode/test", + }, +})) + +import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types" + +vi.mock("../fetchers/modelCache", () => ({ + getModels: vi.fn().mockResolvedValue({ [deepInfraDefaultModelId]: deepInfraDefaultModelInfo, }), - getModelsFromCache: vitest.fn().mockReturnValue(undefined), + getModelsFromCache: vi.fn().mockReturnValue(undefined), })) -import OpenAI from "openai" +import type { Anthropic } from "@anthropic-ai/sdk" +import type { ApiHandlerOptions } from "../../../shared/api" import { DeepInfraHandler } from "../deepinfra" describe("DeepInfraHandler", () => { let handler: DeepInfraHandler + let mockOptions: ApiHandlerOptions beforeEach(() => { + mockOptions = { + deepInfraApiKey: "test-api-key", + deepInfraModelId: deepInfraDefaultModelId, + } + handler = new DeepInfraHandler(mockOptions) vi.clearAllMocks() - mockCreate.mockClear() - mockWithResponse.mockClear() - - handler = new DeepInfraHandler({}) - }) - - it("should use the correct DeepInfra base URL", () => { - expect(OpenAI).toHaveBeenCalledWith( - expect.objectContaining({ - baseURL: "https://api.deepinfra.com/v1/openai", - }), - ) }) - it("should use the provided API key", () => { - vi.clearAllMocks() - - const deepInfraApiKey = "test-api-key" - new DeepInfraHandler({ deepInfraApiKey }) + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(DeepInfraHandler) + expect(handler.getModel().id).toBe(deepInfraDefaultModelId) + }) - expect(OpenAI).toHaveBeenCalledWith( - expect.objectContaining({ - apiKey: deepInfraApiKey, - }), - ) + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new DeepInfraHandler({ + ...mockOptions, + deepInfraModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(deepInfraDefaultModelId) + }) }) - it("should return default model when no model is specified", () => { - const model = handler.getModel() - expect(model.id).toBe(deepInfraDefaultModelId) - expect(model.info).toEqual(deepInfraDefaultModelInfo) - }) + describe("constructor provider creation", () => { + it("should create provider with correct options including DEFAULT_HEADERS", () => { + const testHandler = new DeepInfraHandler({ + deepInfraApiKey: "my-key", + deepInfraBaseUrl: "https://custom.deepinfra.com/v1", + }) - it("createMessage should yield text content from stream", async () => { - const testContent = "This is test content" - - mockWithResponse.mockResolvedValueOnce({ - data: { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: { content: testContent } }], - }, - }) - .mockResolvedValueOnce({ done: true }), + expect(testHandler).toBeInstanceOf(DeepInfraHandler) + expect(mockCreateDeepInfra).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "my-key", + baseURL: "https://custom.deepinfra.com/v1", + headers: { + "HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline", + "X-Title": "Roo Code", + "User-Agent": "RooCode/test", + "X-Deepinfra-Source": 
"roo-code", + "X-Deepinfra-Version": "2025-08-25", + }, }), - }, + ) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should use default base URL when not provided", () => { + const testHandler = new DeepInfraHandler({ + deepInfraApiKey: "test-key", + }) - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "text", - text: testContent, + expect(testHandler).toBeInstanceOf(DeepInfraHandler) + expect(mockCreateDeepInfra).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://api.deepinfra.com/v1/openai", + }), + ) }) - }) - it("createMessage should yield reasoning content from stream", async () => { - const testReasoning = "Test reasoning content" - - mockWithResponse.mockResolvedValueOnce({ - data: { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: { reasoning_content: testReasoning } }], - }, - }) - .mockResolvedValueOnce({ done: true }), + it("should use 'not-provided' as API key when not set", () => { + const handlerWithoutKey = new DeepInfraHandler({}) + + expect(handlerWithoutKey).toBeInstanceOf(DeepInfraHandler) + expect(mockCreateDeepInfra).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "not-provided", }), - }, + ) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() + it("should cache provider instance (not re-create per request)", () => { + const testHandler = new DeepInfraHandler({ + deepInfraApiKey: "my-key", + }) + mockCreateDeepInfra.mockClear() + + // Access getLanguageModel multiple times - should NOT call createDeepInfra again + ;(testHandler as any).getLanguageModel("model-1") + ;(testHandler as any).getLanguageModel("model-2") - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "reasoning", - text: testReasoning, + expect(mockCreateDeepInfra).not.toHaveBeenCalled() }) }) - it("createMessage should yield usage data from stream", async () => { - mockWithResponse.mockResolvedValueOnce({ - data: { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: {} }], - usage: { - prompt_tokens: 10, - completion_tokens: 20, - prompt_tokens_details: { - cache_write_tokens: 15, - cached_tokens: 5, - }, - }, - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - }, + describe("getModel", () => { + it("should return default model when no model is specified", () => { + const model = handler.getModel() + expect(model.id).toBe(deepInfraDefaultModelId) + expect(model.info).toEqual(deepInfraDefaultModelInfo) + }) + + it("should return default model info for unknown model", () => { + const handlerWithUnknown = new DeepInfraHandler({ + ...mockOptions, + deepInfraModelId: "unknown/model", + }) + const model = handlerWithUnknown.getModel() + expect(model.id).toBe("unknown/model") + expect(model.info).toEqual(deepInfraDefaultModelInfo) }) - const stream = handler.createMessage("system prompt", []) - const firstChunk = await stream.next() - - expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "usage", - inputTokens: 10, - outputTokens: 20, - cacheWriteTokens: 15, - cacheReadTokens: 5, - totalCost: expect.any(Number), + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") }) }) - 
describe("Native Tool Calling", () => { - const testTools = [ + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ { - type: "function" as const, - function: { - name: "test_tool", - description: "A test tool", - parameters: { - type: "object", - properties: { - arg1: { type: "string", description: "First argument" }, - }, - required: ["arg1"], + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", }, - }, + ], }, ] - it("should include tools in request when model supports native tools and tools are provided", async () => { - mockWithResponse.mockResolvedValueOnce({ - data: { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - }, + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, }) - const messageGenerator = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, }) - await messageGenerator.next() - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ - name: "test_tool", - }), - }), - ]), - }), - ) - // parallel_tool_calls should be true by default when not explicitly set - const callArgs = mockCreate.mock.calls[0][0] - expect(callArgs).toHaveProperty("parallel_tool_calls", true) + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response") }) - it("should include tool_choice when provided", async () => { - mockWithResponse.mockResolvedValueOnce({ - data: { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } + it("should handle reasoning content in stream", async () => { + async function* mockFullStream() { + yield { type: "reasoning", text: "Let me think..." 
} + yield { type: "text-delta", text: "Answer" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning") + expect(reasoningChunks).toHaveLength(1) + expect(reasoningChunks[0].text).toBe("Let me think...") + + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Answer") + }) + + it("should include usage information with cost calculation", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks).toHaveLength(1) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) + expect(usageChunks[0].totalCost).toEqual(expect.any(Number)) + }) + + it("should pass tools and toolChoice to streamText", async () => { + const testTools = [ + { + type: "function" as const, + function: { + name: "test_tool", + description: "A test tool", + parameters: { + type: "object", + properties: { + arg1: { type: "string", description: "First argument" }, + }, + required: ["arg1"], }, - }), + }, }, + ] + + async function* mockFullStream() { + yield { type: "finish", finishReason: "stop" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), }) - const messageGenerator = handler.createMessage("test prompt", [], { + const stream = handler.createMessage(systemPrompt, messages, { taskId: "test-task-id", tools: testTools, tool_choice: "auto", }) - await messageGenerator.next() - expect(mockCreate).toHaveBeenCalledWith( + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - tool_choice: "auto", + system: systemPrompt, + tools: expect.any(Object), }), ) }) - it("should always include tools and tool_choice in request (tools are always present after PR #10841)", async () => { - mockWithResponse.mockResolvedValueOnce({ - data: { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - }, - }) + it("should handle tool call streaming events", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "call_123", + toolName: "test_tool", + } + yield { + type: "tool-input-delta", + id: "call_123", + delta: '{"arg1":', + } + yield { + type: "tool-input-delta", + id: "call_123", + delta: '"value"}', + } + yield { + type: "tool-input-end", + id: "call_123", + } + } - const messageGenerator = handler.createMessage("test prompt", [], { - taskId: 
"test-task-id", + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), }) - await messageGenerator.next() - - const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0] - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - expect(callArgs).toHaveProperty("tools") - expect(callArgs).toHaveProperty("tool_choice") - // parallel_tool_calls should be true by default when not explicitly set - expect(callArgs).toHaveProperty("parallel_tool_calls", true) - }) - it("should yield tool_call_partial chunks during streaming", async () => { - mockWithResponse.mockResolvedValueOnce({ - data: { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_123", - function: { - name: "test_tool", - arguments: '{"arg1":', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - function: { - arguments: '"value"}', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - }, - }) - - const stream = handler.createMessage("test prompt", [], { + const stream = handler.createMessage(systemPrompt, messages, { taskId: "test-task-id", - tools: testTools, }) - const chunks = [] + const chunks: any[] = [] for await (const chunk of stream) { chunks.push(chunk) } - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: "call_123", - name: "test_tool", - arguments: '{"arg1":', - }) + const startChunks = chunks.filter((chunk) => chunk.type === "tool_call_start") + expect(startChunks).toHaveLength(1) + expect(startChunks[0].id).toBe("call_123") + expect(startChunks[0].name).toBe("test_tool") - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"value"}', - }) + const deltaChunks = chunks.filter((chunk) => chunk.type === "tool_call_delta") + expect(deltaChunks).toHaveLength(2) }) - it("should set parallel_tool_calls based on metadata", async () => { - mockWithResponse.mockResolvedValueOnce({ - data: { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - }, + it("should handle errors using handleAiSdkError", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "" } + throw new Error("API error") + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), }) - const messageGenerator = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - parallelToolCalls: true, + const stream = handler.createMessage(systemPrompt, messages) + await expect(async () => { + for await (const _ of stream) { + // consume stream + } + }).rejects.toThrow() + }) + }) + + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", }) - await messageGenerator.next() - expect(mockCreate).toHaveBeenCalledWith( + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion") + expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - parallel_tool_calls: true, + prompt: "Test prompt", 
}), ) }) }) - describe("completePrompt", () => { - it("should return text from API", async () => { - const expectedResponse = "This is a test response" - mockCreate.mockResolvedValueOnce({ - choices: [{ message: { content: expectedResponse } }], - }) + describe("processUsageMetrics", () => { + it("should correctly calculate cost with model info", () => { + class TestDeepInfraHandler extends DeepInfraHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any, modelInfo?: any) { + return this.processUsageMetrics(usage, providerMetadata, modelInfo) + } + } + + const testHandler = new TestDeepInfraHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage, {}, deepInfraDefaultModelInfo) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.totalCost).toEqual(expect.any(Number)) + expect(result.totalCost).toBeGreaterThan(0) + }) + + it("should handle cache metrics from providerMetadata", () => { + class TestDeepInfraHandler extends DeepInfraHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any, modelInfo?: any) { + return this.processUsageMetrics(usage, providerMetadata, modelInfo) + } + } + + const testHandler = new TestDeepInfraHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const providerMetadata = { + deepinfra: { + cacheWriteTokens: 15, + cachedTokens: 5, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage, providerMetadata, deepInfraDefaultModelInfo) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(15) + expect(result.cacheReadTokens).toBe(5) + expect(result.totalCost).toEqual(expect.any(Number)) + }) + + it("should handle missing cache metrics gracefully", () => { + class TestDeepInfraHandler extends DeepInfraHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any, modelInfo?: any) { + return this.processUsageMetrics(usage, providerMetadata, modelInfo) + } + } + + const testHandler = new TestDeepInfraHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage, {}, deepInfraDefaultModelInfo) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBeUndefined() + expect(result.cacheReadTokens).toBeUndefined() + }) + + it("should return zero cost without model info", () => { + class TestDeepInfraHandler extends DeepInfraHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any, modelInfo?: any) { + return this.processUsageMetrics(usage, providerMetadata, modelInfo) + } + } + + const testHandler = new TestDeepInfraHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.totalCost).toBe(0) + }) + + it("should use cachedInputTokens from usage details as fallback", () => { + class TestDeepInfraHandler extends DeepInfraHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any, modelInfo?: any) { + return this.processUsageMetrics(usage, providerMetadata, modelInfo) + } + } + + const testHandler = new TestDeepInfraHandler(mockOptions) + + const usage 
= {
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					cachedInputTokens: 25,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage, {}, deepInfraDefaultModelInfo)
+
+			expect(result.cacheReadTokens).toBe(25)
+		})
+	})
 
-			const result = await handler.completePrompt("test prompt")
-			expect(result).toBe(expectedResponse)
+	describe("isAiSdkProvider", () => {
+		it("should return true", () => {
+			expect(handler.isAiSdkProvider()).toBe(true)
 		})
 	})
 })
diff --git a/src/api/providers/deepinfra.ts b/src/api/providers/deepinfra.ts
index e5b10e4e44..2333ea046f 100644
--- a/src/api/providers/deepinfra.ts
+++ b/src/api/providers/deepinfra.ts
@@ -1,44 +1,76 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
+import { createDeepInfra } from "@ai-sdk/deepinfra"
+import { streamText, generateText, ToolSet } from "ai"
 
-import { deepInfraDefaultModelId, deepInfraDefaultModelInfo } from "@roo-code/types"
+import { deepInfraDefaultModelId, deepInfraDefaultModelInfo, type ModelInfo, type ModelRecord } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
 import { calculateApiCostOpenAI } from "../../shared/cost"
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
-import { convertToOpenAiMessages } from "../transform/openai-format"
+import { getModelParams } from "../transform/model-params"
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
 
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
-import { RouterProvider } from "./router-provider"
-import { getModelParams } from "../transform/model-params"
-import { getModels } from "./fetchers/modelCache"
+import { getModels, getModelsFromCache } from "./fetchers/modelCache"
+
+const DEEPINFRA_DEFAULT_BASE_URL = "https://api.deepinfra.com/v1/openai"
+
+const DEEPINFRA_HEADERS = {
+	"X-Deepinfra-Source": "roo-code",
+	"X-Deepinfra-Version": "2025-08-25",
+}
+
+/**
+ * DeepInfra provider using the official @ai-sdk/deepinfra package.
+ * Supports dynamic model fetching, reasoning_effort, prompt caching, and custom cost calculation.
+ */
+export class DeepInfraHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createDeepInfra>
+	protected models: ModelRecord = {}
 
-export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler {
 	constructor(options: ApiHandlerOptions) {
-		super({
-			options: {
-				...options,
-				openAiHeaders: {
-					"X-Deepinfra-Source": "roo-code",
-					"X-Deepinfra-Version": `2025-08-25`,
-				},
+		super()
+		this.options = options
+
+		// Create the DeepInfra provider once in the constructor (cached)
+		this.provider = createDeepInfra({
+			apiKey: this.options.deepInfraApiKey ?? "not-provided",
+			baseURL: this.options.deepInfraBaseUrl || DEEPINFRA_DEFAULT_BASE_URL,
+			headers: {
+				...DEFAULT_HEADERS,
+				...DEEPINFRA_HEADERS,
 			},
-			name: "deepinfra",
-			baseURL: `${options.deepInfraBaseUrl || "https://api.deepinfra.com/v1/openai"}`,
-			apiKey: options.deepInfraApiKey || "not-provided",
-			modelId: options.deepInfraModelId,
-			defaultModelId: deepInfraDefaultModelId,
-			defaultModelInfo: deepInfraDefaultModelInfo,
 		})
 	}
 
-	public override async fetchModel() {
-		this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
+	/**
+	 * Fetch models dynamically from the DeepInfra API and return the resolved model.
+	 */
+	async fetchModel() {
+		this.models = await getModels({
+			provider: "deepinfra",
+			apiKey: this.options.deepInfraApiKey,
+			baseUrl: this.options.deepInfraBaseUrl || DEEPINFRA_DEFAULT_BASE_URL,
+		})
 		return this.getModel()
 	}
 
 	override getModel() {
+		const cachedModels = getModelsFromCache("deepinfra")
+		if (cachedModels) {
+			this.models = cachedModels
+		}
+
 		const id = this.options.deepInfraModelId ?? deepInfraDefaultModelId
 		const info = this.models[id] ?? deepInfraDefaultModelInfo
 
@@ -52,112 +84,137 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion
 		return { id, info, ...params }
 	}
 
+	/**
+	 * Get the language model for the given model ID.
+	 */
+	protected getLanguageModel(modelId: string) {
+		return this.provider(modelId)
+	}
+
+	/**
+	 * Process usage metrics with DeepInfra-specific cost calculation using calculateApiCostOpenAI.
+	 */
+	protected processUsageMetrics(
+		usage: {
+			inputTokens?: number
+			outputTokens?: number
+			details?: {
+				cachedInputTokens?: number
+				reasoningTokens?: number
+			}
+		},
+		providerMetadata?: Record<string, any>,
+		modelInfo?: ModelInfo,
+	): ApiStreamUsageChunk {
+		const inputTokens = usage.inputTokens || 0
+		const outputTokens = usage.outputTokens || 0
+
+		const cacheWriteTokens = providerMetadata?.deepinfra?.cacheWriteTokens ?? undefined
+		const cacheReadTokens =
+			providerMetadata?.deepinfra?.cachedTokens ?? usage.details?.cachedInputTokens ?? undefined
+
+		const { totalCost } = modelInfo
+			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+			: { totalCost: 0 }
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens || undefined,
+			cacheReadTokens: cacheReadTokens || undefined,
+			totalCost,
+		}
+	}
+
+	/**
+	 * Get the max output tokens parameter, only when includeMaxTokens is enabled.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		if (this.options.includeMaxTokens !== true || !info.maxTokens) {
+			return undefined
+		}
+		return this.options.modelMaxTokens || info.maxTokens
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 * Handles dynamic model fetching, reasoning_effort, prompt caching, and custom cost metrics.
+	 */
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
-		_metadata?: ApiHandlerCreateMessageMetadata,
+		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		// Ensure we have up-to-date model metadata
-		await this.fetchModel()
-		const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel()
-		let prompt_cache_key = undefined
-		if (info.supportsPromptCache && _metadata?.taskId) {
-			prompt_cache_key = _metadata.taskId
-		}
-
-		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
-			model: modelId,
-			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-			stream: true,
-			stream_options: { include_usage: true },
-			reasoning_effort,
-			prompt_cache_key,
-			tools: this.convertToolsForOpenAI(_metadata?.tools),
-			tool_choice: _metadata?.tool_choice,
-			parallel_tool_calls: _metadata?.parallelToolCalls ?? true,
-		} as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
-
-		if (this.supportsTemperature(modelId)) {
-			requestOptions.temperature = this.options.modelTemperature ?? 0
-		}
+		const { id: modelId, info, temperature, reasoningEffort } = await this.fetchModel()
+		const languageModel = this.getLanguageModel(modelId)
 
-		if (this.options.includeMaxTokens === true && info.maxTokens) {
-			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
-		}
+		const aiSdkMessages = convertToAiSdkMessages(messages)
 
-		const { data: stream } = await this.client.chat.completions.create(requestOptions).withResponse()
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
 
-		let lastUsage: OpenAI.CompletionUsage | undefined
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
+		// Build DeepInfra-specific provider options
+		const deepinfraProviderOptions: Record<string, any> = {}
+		if (reasoningEffort) {
+			deepinfraProviderOptions.reasoningEffort = reasoningEffort
+		}
+		if (info.supportsPromptCache && metadata?.taskId) {
+			deepinfraProviderOptions.promptCacheKey = metadata.taskId
+		}
 
-			if (delta?.content) {
-				yield { type: "text", text: delta.content }
-			}
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
+			...(Object.keys(deepinfraProviderOptions).length > 0 && {
+				providerOptions: { deepinfra: deepinfraProviderOptions },
+			}),
+		}
 
-			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
-				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
-			}
+		const result = streamText(requestOptions)
 
-			// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
-			if (delta?.tool_calls) {
-				for (const toolCall of delta.tool_calls) {
-					yield {
-						type: "tool_call_partial",
-						index: toolCall.index,
-						id: toolCall.id,
-						name: toolCall.function?.name,
-						arguments: toolCall.function?.arguments,
-					}
+		try {
+			for await (const part of result.fullStream) {
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
 				}
 			}
 
-			if (chunk.usage) {
-				lastUsage = chunk.usage
+			const usage = await result.usage
+			const providerMetadata = await result.providerMetadata
+			if (usage) {
+				yield this.processUsageMetrics(usage, providerMetadata as any, info)
 			}
-		}
-
-		if (lastUsage) {
-			yield this.processUsageMetrics(lastUsage, info)
+		} catch (error) {
+			throw handleAiSdkError(error, "DeepInfra")
 		}
 	}
 
+	/**
+	 * Complete a prompt using AI SDK generateText.
+	 */
 	async completePrompt(prompt: string): Promise<string> {
 		await this.fetchModel()
-		const { id: modelId, info } = this.getModel()
-
-		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
-			model: modelId,
-			messages: [{ role: "user", content: prompt }],
-		}
-		if (this.supportsTemperature(modelId)) {
-			requestOptions.temperature = this.options.modelTemperature ?? 0
-		}
-		if (this.options.includeMaxTokens === true && info.maxTokens) {
-			;(requestOptions as any).max_completion_tokens = this.options.modelMaxTokens || info.maxTokens
-		}
+		const { id: modelId, temperature } = this.getModel()
+		const languageModel = this.getLanguageModel(modelId)
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature,
+		})
 
-		const resp = await this.client.chat.completions.create(requestOptions)
-		return resp.choices[0]?.message?.content || ""
+		return text
 	}
 
-	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
-		const inputTokens = usage?.prompt_tokens || 0
-		const outputTokens = usage?.completion_tokens || 0
-		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
-		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
-
-		const { totalCost } = modelInfo
-			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
-			: { totalCost: 0 }
-
-		return {
-			type: "usage",
-			inputTokens,
-			outputTokens,
-			cacheWriteTokens: cacheWriteTokens || undefined,
-			cacheReadTokens: cacheReadTokens || undefined,
-			totalCost,
-		}
+	override isAiSdkProvider(): boolean {
+		return true
 	}
 }
diff --git a/src/package.json b/src/package.json
index 3e0201c641..9bcc4b4a25 100644
--- a/src/package.json
+++ b/src/package.json
@@ -451,6 +451,7 @@
 	},
 	"dependencies": {
 		"@ai-sdk/cerebras": "^1.0.0",
+		"@ai-sdk/deepinfra": "^2.0.31",
 		"@ai-sdk/deepseek": "^2.0.14",
 		"@ai-sdk/fireworks": "^2.0.26",
 		"@ai-sdk/google": "^3.0.20",
@@ -458,7 +459,6 @@
 		"@ai-sdk/groq": "^3.0.19",
 		"@ai-sdk/mistral": "^3.0.0",
 		"@ai-sdk/xai": "^3.0.46",
-		"sambanova-ai-provider": "^1.2.2",
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",
 		"@anthropic-ai/vertex-sdk": "^0.7.0",
@@ -518,6 +518,7 @@
 		"puppeteer-core": "^23.4.0",
 		"reconnecting-eventsource": "^1.6.4",
 		"safe-stable-stringify": "^2.5.0",
+		"sambanova-ai-provider": "^1.2.2",
 		"sanitize-filename": "^1.6.3",
 		"say": "^0.16.0",
 		"semver-compare": "^1.0.0",