diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 0a47ae416ae..5f1ce8fd785 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -746,6 +746,9 @@ importers: src: dependencies: + '@ai-sdk/anthropic': + specifier: ^3.0.37 + version: 3.0.37(zod@3.25.76) '@ai-sdk/cerebras': specifier: ^1.0.0 version: 1.0.35(zod@3.25.76) @@ -1423,6 +1426,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/anthropic@3.0.37': + resolution: {integrity: sha512-tEgcJPw+a6obbF+SHrEiZsx3DNxOHqeY8bK4IpiNsZ8YPZD141R34g3lEAaQnmNN5mGsEJ8SXoEDabuzi8wFJQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35': resolution: {integrity: sha512-JrNdMYptrOUjNthibgBeAcBjZ/H+fXb49sSrWhOx5Aq8eUcrYvwQ2DtSAi8VraHssZu78NAnBMrgFWSUOTXFxw==} engines: {node: '>=18'} @@ -11136,6 +11145,12 @@ snapshots: '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/anthropic@3.0.37(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35(zod@3.25.76)': dependencies: '@ai-sdk/openai-compatible': 1.0.31(zod@3.25.76) diff --git a/src/api/providers/__tests__/anthropic.spec.ts b/src/api/providers/__tests__/anthropic.spec.ts index 7a107edbc8b..71396c330a6 100644 --- a/src/api/providers/__tests__/anthropic.spec.ts +++ b/src/api/providers/__tests__/anthropic.spec.ts @@ -1,90 +1,56 @@ // npx vitest run src/api/providers/__tests__/anthropic.spec.ts -import { AnthropicHandler } from "../anthropic" -import { ApiHandlerOptions } from "../../../shared/api" +const mockCaptureException = vitest.fn() -// Mock TelemetryService vitest.mock("@roo-code/telemetry", () => ({ TelemetryService: { instance: { - captureException: vitest.fn(), + captureException: (...args: unknown[]) => mockCaptureException(...args), }, }, })) -const mockCreate = vitest.fn() - -vitest.mock("@anthropic-ai/sdk", () => { - const mockAnthropicConstructor = vitest.fn().mockImplementation(() => ({ - messages: { - create: mockCreate.mockImplementation(async (options) => { - if (!options.stream) { - return { - id: "test-completion", - content: [{ type: "text", text: "Test response" }], - role: "assistant", - model: options.model, - usage: { - input_tokens: 10, - output_tokens: 5, - }, - } - } - return { - async *[Symbol.asyncIterator]() { - yield { - type: "message_start", - message: { - usage: { - input_tokens: 100, - output_tokens: 50, - cache_creation_input_tokens: 20, - cache_read_input_tokens: 10, - }, - }, - } - yield { - type: "content_block_start", - index: 0, - content_block: { - type: "text", - text: "Hello", - }, - } - yield { - type: "content_block_delta", - delta: { - type: "text_delta", - text: " world", - }, - } - }, - } - }), - }, - })) +// Mock the AI SDK functions +const mockStreamText = vitest.fn() +const mockGenerateText = vitest.fn() +vitest.mock("ai", async (importOriginal) => { + const original = await importOriginal() return { - Anthropic: mockAnthropicConstructor, + ...original, + streamText: (...args: unknown[]) => mockStreamText(...args), + generateText: (...args: unknown[]) => mockGenerateText(...args), } }) -// Import after mock -import { Anthropic } from "@anthropic-ai/sdk" +// Mock createAnthropic to capture constructor options +const mockCreateAnthropic = vitest.fn().mockReturnValue(() => ({})) + +vitest.mock("@ai-sdk/anthropic", () => ({ + createAnthropic: (...args: unknown[]) => mockCreateAnthropic(...args), +})) -const mockAnthropicConstructor = vitest.mocked(Anthropic) +import { Anthropic } from "@anthropic-ai/sdk" 
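+// The Anthropic SDK import above is kept for its Anthropic.Messages.* types used in
+// test fixtures; no real client is constructed since createAnthropic is mocked.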
+import { type ModelInfo, anthropicDefaultModelId, ApiProviderError } from "@roo-code/types" +import { AnthropicHandler } from "../anthropic" +import { ApiHandlerOptions } from "../../../shared/api" describe("AnthropicHandler", () => { let handler: AnthropicHandler let mockOptions: ApiHandlerOptions beforeEach(() => { + mockCaptureException.mockClear() + mockStreamText.mockClear() + mockGenerateText.mockClear() + mockCreateAnthropic.mockClear() + mockCreateAnthropic.mockReturnValue(() => ({})) + mockOptions = { apiKey: "test-api-key", apiModelId: "claude-3-5-sonnet-20241022", } handler = new AnthropicHandler(mockOptions) - vitest.clearAllMocks() }) describe("constructor", () => { @@ -93,653 +59,625 @@ describe("AnthropicHandler", () => { expect(handler.getModel().id).toBe(mockOptions.apiModelId) }) - it("should initialize with undefined API key", () => { - // The SDK will handle API key validation, so we just verify it initializes - const handlerWithoutKey = new AnthropicHandler({ - ...mockOptions, - apiKey: undefined, - }) - expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler) + it("should create provider with apiKey", () => { + mockCreateAnthropic.mockClear() + new AnthropicHandler(mockOptions) + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "test-api-key", + }), + ) }) it("should use custom base URL if provided", () => { - const customBaseUrl = "https://custom.anthropic.com" - const handlerWithCustomUrl = new AnthropicHandler({ + mockCreateAnthropic.mockClear() + new AnthropicHandler({ ...mockOptions, - anthropicBaseUrl: customBaseUrl, + anthropicBaseUrl: "https://custom.anthropic.com", }) - expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler) + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: "https://custom.anthropic.com", + }), + ) + }) + + it("should pass undefined baseURL when anthropicBaseUrl is not provided", () => { + mockCreateAnthropic.mockClear() + new AnthropicHandler(mockOptions) + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: undefined, + }), + ) }) - it("use apiKey for passing token if anthropicUseAuthToken is not set", () => { - const handlerWithCustomUrl = new AnthropicHandler({ + it("should use Bearer auth when anthropicBaseUrl and anthropicUseAuthToken are set", () => { + mockCreateAnthropic.mockClear() + new AnthropicHandler({ ...mockOptions, + anthropicBaseUrl: "https://custom.anthropic.com", + anthropicUseAuthToken: true, }) - expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler) - expect(mockAnthropicConstructor).toHaveBeenCalledTimes(1) - expect(mockAnthropicConstructor.mock.calls[0]![0]!.apiKey).toEqual("test-api-key") - expect(mockAnthropicConstructor.mock.calls[0]![0]!.authToken).toBeUndefined() + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "", + headers: expect.objectContaining({ + Authorization: "Bearer test-api-key", + }), + }), + ) }) - it("use apiKey for passing token if anthropicUseAuthToken is set but custom base URL is not given", () => { - const handlerWithCustomUrl = new AnthropicHandler({ + it("should use apiKey auth when anthropicUseAuthToken is set but no base URL", () => { + mockCreateAnthropic.mockClear() + new AnthropicHandler({ ...mockOptions, anthropicUseAuthToken: true, }) - expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler) - expect(mockAnthropicConstructor).toHaveBeenCalledTimes(1) - 
expect(mockAnthropicConstructor.mock.calls[0]![0]!.apiKey).toEqual("test-api-key") - expect(mockAnthropicConstructor.mock.calls[0]![0]!.authToken).toBeUndefined() + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "test-api-key", + }), + ) + // Should not have Authorization header + const calledHeaders = mockCreateAnthropic.mock.calls[0][0].headers + expect(calledHeaders.Authorization).toBeUndefined() }) - it("use authToken for passing token if both of anthropicBaseUrl and anthropicUseAuthToken are set", () => { - const customBaseUrl = "https://custom.anthropic.com" - const handlerWithCustomUrl = new AnthropicHandler({ + it("should initialize with undefined API key", () => { + const handlerWithoutKey = new AnthropicHandler({ ...mockOptions, - anthropicBaseUrl: customBaseUrl, - anthropicUseAuthToken: true, + apiKey: undefined, }) - expect(handlerWithCustomUrl).toBeInstanceOf(AnthropicHandler) - expect(mockAnthropicConstructor).toHaveBeenCalledTimes(1) - expect(mockAnthropicConstructor.mock.calls[0]![0]!.authToken).toEqual("test-api-key") - expect(mockAnthropicConstructor.mock.calls[0]![0]!.apiKey).toBeUndefined() + expect(handlerWithoutKey).toBeInstanceOf(AnthropicHandler) }) }) describe("createMessage", () => { const systemPrompt = "You are a helpful assistant." + const mockMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "First message" }], + }, + { + role: "assistant", + content: [{ type: "text" as const, text: "Response" }], + }, + { + role: "user", + content: [{ type: "text" as const, text: "Second message" }], + }, + ] - it("should handle prompt caching for supported models", async () => { - const stream = handler.createMessage(systemPrompt, [ - { - role: "user", - content: [{ type: "text" as const, text: "First message" }], - }, - { - role: "assistant", - content: [{ type: "text" as const, text: "Response" }], - }, - { - role: "user", - content: [{ type: "text" as const, text: "Second message" }], - }, - ]) + it("should handle text messages correctly", async () => { + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world!" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([]), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + providerMetadata: Promise.resolve({}), + }) + const stream = handler.createMessage(systemPrompt, mockMessages) const chunks: any[] = [] + for await (const chunk of stream) { chunks.push(chunk) } - // Verify usage information - const usageChunk = chunks.find((chunk) => chunk.type === "usage") - expect(usageChunk).toBeDefined() - expect(usageChunk?.inputTokens).toBe(100) - expect(usageChunk?.outputTokens).toBe(50) - expect(usageChunk?.cacheWriteTokens).toBe(20) - expect(usageChunk?.cacheReadTokens).toBe(10) - - // Verify text content - const textChunks = chunks.filter((chunk) => chunk.type === "text") + // Should have text chunks + usage + const textChunks = chunks.filter((c) => c.type === "text") expect(textChunks).toHaveLength(2) - expect(textChunks[0].text).toBe("Hello") - expect(textChunks[1].text).toBe(" world") + expect(textChunks[0]).toEqual({ type: "text", text: "Hello" }) + expect(textChunks[1]).toEqual({ type: "text", text: " world!" 
}) + + // Should have usage chunk + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0]).toMatchObject({ + type: "usage", + inputTokens: 100, + outputTokens: 50, + }) - // Verify API - expect(mockCreate).toHaveBeenCalled() + // Verify streamText was called + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0, + }), + ) }) - }) - describe("completePrompt", () => { - it("should complete prompt successfully", async () => { - const result = await handler.completePrompt("Test prompt") - expect(result).toBe("Test response") - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.apiModelId, - messages: [{ role: "user", content: "Test prompt" }], - max_tokens: 8192, - temperature: 0, - thinking: undefined, - stream: false, + it("should pass beta headers", async () => { + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "Hi" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), }) - }) - it("should handle API errors", async () => { - mockCreate.mockRejectedValueOnce(new Error("Anthropic completion error: API Error")) - await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error") - }) + const stream = handler.createMessage(systemPrompt, mockMessages) + for await (const _chunk of stream) { + // consume + } - it("should handle non-text content", async () => { - mockCreate.mockImplementationOnce(async () => ({ - content: [{ type: "image" }], - })) - const result = await handler.completePrompt("Test prompt") - expect(result).toBe("") + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + headers: expect.objectContaining({ + "anthropic-beta": expect.stringContaining("fine-grained-tool-streaming-2025-05-14"), + }), + }), + ) }) - it("should handle empty response", async () => { - mockCreate.mockImplementationOnce(async () => ({ - content: [{ type: "text", text: "" }], - })) - const result = await handler.completePrompt("Test prompt") - expect(result).toBe("") - }) - }) + it("should include prompt-caching beta for supported models", async () => { + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "Hi" } + })() - describe("getModel", () => { - it("should return default model if no model ID is provided", () => { - const handlerWithoutModel = new AnthropicHandler({ - ...mockOptions, - apiModelId: undefined, + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), }) - const model = handlerWithoutModel.getModel() - expect(model.id).toBeDefined() - expect(model.info).toBeDefined() - }) - it("should return specified model if valid model ID is provided", () => { - const model = handler.getModel() - expect(model.id).toBe(mockOptions.apiModelId) - expect(model.info).toBeDefined() - expect(model.info.maxTokens).toBe(8192) - expect(model.info.contextWindow).toBe(200_000) - expect(model.info.supportsImages).toBe(true) - expect(model.info.supportsPromptCache).toBe(true) - }) - - it("honors custom maxTokens for thinking models", () => { - const handler = new AnthropicHandler({ - apiKey: "test-api-key", - apiModelId: "claude-3-7-sonnet-20250219:thinking", - 
modelMaxTokens: 32_768, - modelMaxThinkingTokens: 16_384, - }) + const stream = handler.createMessage(systemPrompt, mockMessages) + for await (const _chunk of stream) { + // consume + } - const result = handler.getModel() - expect(result.maxTokens).toBe(32_768) - expect(result.reasoningBudget).toEqual(16_384) - expect(result.temperature).toBe(1.0) + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + headers: expect.objectContaining({ + "anthropic-beta": expect.stringContaining("prompt-caching-2024-07-31"), + }), + }), + ) }) - it("does not honor custom maxTokens for non-thinking models", () => { - const handler = new AnthropicHandler({ - apiKey: "test-api-key", - apiModelId: "claude-3-7-sonnet-20250219", - modelMaxTokens: 32_768, - modelMaxThinkingTokens: 16_384, + it("should use system array with cache control for cache-supported models", async () => { + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "Hi" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), }) - const result = handler.getModel() - expect(result.maxTokens).toBe(8192) - expect(result.reasoningBudget).toBeUndefined() - expect(result.temperature).toBe(0) - }) + const stream = handler.createMessage(systemPrompt, mockMessages) + for await (const _chunk of stream) { + // consume + } - it("should handle Claude 4.5 Sonnet model correctly", () => { - const handler = new AnthropicHandler({ - apiKey: "test-api-key", - apiModelId: "claude-sonnet-4-5", + // For cache-supported models, system should be an array with providerOptions + const callArgs = mockStreamText.mock.calls[0][0] + expect(Array.isArray(callArgs.system)).toBe(true) + expect(callArgs.system[0]).toEqual({ + type: "text", + text: systemPrompt, + providerOptions: { anthropic: { cacheControl: { type: "ephemeral" } } }, }) - const model = handler.getModel() - expect(model.id).toBe("claude-sonnet-4-5") - expect(model.info.maxTokens).toBe(64000) - expect(model.info.contextWindow).toBe(200000) - expect(model.info.supportsReasoningBudget).toBe(true) }) - it("should enable 1M context for Claude 4.5 Sonnet when beta flag is set", () => { - const handler = new AnthropicHandler({ - apiKey: "test-api-key", - apiModelId: "claude-sonnet-4-5", - anthropicBeta1MContext: true, + it("should handle API errors", async () => { + const mockError = new Error("Anthropic API error") + // eslint-disable-next-line require-yield + const mockFullStream = (async function* () { + throw mockError + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([]), + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), }) - const model = handler.getModel() - expect(model.info.contextWindow).toBe(1000000) - expect(model.info.inputPrice).toBe(6.0) - expect(model.info.outputPrice).toBe(22.5) - }) - }) - describe("reasoning block filtering", () => { - const systemPrompt = "You are a helpful assistant." 
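+ // The mocked fullStream throws lazily: createMessage() itself does not reject,
+ // so the error only surfaces when the for-await below pulls the first chunk.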
+ const stream = handler.createMessage(systemPrompt, mockMessages) - it("should filter out internal reasoning blocks before sending to API", async () => { - handler = new AnthropicHandler({ - apiKey: "test-api-key", - apiModelId: "claude-3-5-sonnet-20241022", - }) + await expect(async () => { + for await (const _chunk of stream) { + // Should throw + } + }).rejects.toThrow("Anthropic API error") - // Messages with internal reasoning blocks (from stored conversation history) - const messagesWithReasoning: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: "Hello", - }, - { - role: "assistant", - content: [ - { - type: "reasoning" as any, - text: "This is internal reasoning that should be filtered", - }, - { - type: "text", - text: "This is the response", - }, - ], - }, - { - role: "user", - content: "Continue", - }, - ] + // Should capture telemetry + expect(mockCaptureException).toHaveBeenCalled() + }) + + it("should handle reasoning stream parts", async () => { + const mockFullStream = (async function* () { + yield { type: "reasoning-delta", text: "Let me think..." } + yield { type: "text-delta", text: "The answer is 42." } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([{ type: "text", text: "Let me think...", signature: "sig123" }]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) - const stream = handler.createMessage(systemPrompt, messagesWithReasoning) + const stream = handler.createMessage(systemPrompt, mockMessages) const chunks: any[] = [] for await (const chunk of stream) { chunks.push(chunk) } - // Verify the API was called with filtered messages (no reasoning blocks) - const calledMessages = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0].messages - expect(calledMessages).toHaveLength(3) + const reasoningChunks = chunks.filter((c) => c.type === "reasoning") + expect(reasoningChunks).toHaveLength(1) + expect(reasoningChunks[0]).toEqual({ type: "reasoning", text: "Let me think..." }) + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0]).toEqual({ type: "text", text: "The answer is 42." }) + }) + + it("should capture thought signature from reasoning", async () => { + const mockFullStream = (async function* () { + yield { type: "reasoning-delta", text: "Thinking..." 
} + yield { type: "text-delta", text: "Answer" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([{ type: "text", text: "Thinking...", signature: "thought-sig-abc" }]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) - // Check assistant message - should have reasoning block filtered out - const assistantMessage = calledMessages.find((m: any) => m.role === "assistant") - expect(assistantMessage).toBeDefined() - expect(assistantMessage.content).toEqual([{ type: "text", text: "This is the response" }]) + const stream = handler.createMessage(systemPrompt, mockMessages) + for await (const _chunk of stream) { + // consume + } - // Verify reasoning blocks were NOT sent to the API - expect(assistantMessage.content).not.toContainEqual(expect.objectContaining({ type: "reasoning" })) + expect(handler.getThoughtSignature()).toBe("thought-sig-abc") }) - it("should filter empty messages after removing all reasoning blocks", async () => { - handler = new AnthropicHandler({ - apiKey: "test-api-key", - apiModelId: "claude-3-5-sonnet-20241022", + it("should capture redacted thinking blocks from reasoning", async () => { + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "Answer" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([ + { type: "text", text: "Thinking...", signature: "sig1" }, + { type: "redacted", data: "base64redacteddata" }, + ]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), }) - // Message with only reasoning content (should be completely filtered) - const messagesWithOnlyReasoning: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: "Hello", - }, - { - role: "assistant", - content: [ - { - type: "reasoning" as any, - text: "Only reasoning, no actual text", - }, - ], - }, - { - role: "user", - content: "Continue", - }, - ] + const stream = handler.createMessage(systemPrompt, mockMessages) + for await (const _chunk of stream) { + // consume + } + + const redacted = handler.getRedactedThinkingBlocks() + expect(redacted).toEqual([{ type: "redacted_thinking", data: "base64redacteddata" }]) + }) + + it("should handle usage with cache tokens from provider metadata", async () => { + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "Hi" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([]), + usage: Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + details: { cachedInputTokens: 30 }, + }), + providerMetadata: Promise.resolve({ + anthropic: { cacheCreationInputTokens: 20 }, + }), + }) - const stream = handler.createMessage(systemPrompt, messagesWithOnlyReasoning) + const stream = handler.createMessage(systemPrompt, mockMessages) const chunks: any[] = [] for await (const chunk of stream) { chunks.push(chunk) } - // Verify empty message was filtered out - const calledMessages = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0].messages - expect(calledMessages.length).toBe(2) // Only the two user messages - expect(calledMessages.every((m: any) => m.role === "user")).toBe(true) + const usageChunk = chunks.find((c) => c.type === "usage") + expect(usageChunk).toMatchObject({ + type: "usage", + inputTokens: 100, + outputTokens: 50, + cacheReadTokens: 30, + cacheWriteTokens: 20, + }) + expect(usageChunk.totalCost).toBeDefined() 
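+ // totalCost is derived from these token counts via calculateApiCostAnthropic,
+ // so the test only asserts it is defined rather than pinning an exact price.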
}) - }) - describe("native tool calling", () => { - const systemPrompt = "You are a helpful assistant." - const messages: Anthropic.Messages.MessageParam[] = [ - { - role: "user", - content: [{ type: "text" as const, text: "What's the weather in London?" }], - }, - ] + it("should pass tools to streamText", async () => { + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "Hi" } + })() - const mockTools = [ - { - type: "function" as const, - function: { - name: "get_weather", - description: "Get the current weather", - parameters: { - type: "object", - properties: { - location: { type: "string" }, + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + + const mockTools = [ + { + type: "function" as const, + function: { + name: "get_weather", + description: "Get the current weather", + parameters: { + type: "object", + properties: { location: { type: "string" } }, + required: ["location"], }, - required: ["location"], }, }, - }, - ] + ] - it("should include tools in request when tools are provided", async () => { - const stream = handler.createMessage(systemPrompt, messages, { + const stream = handler.createMessage(systemPrompt, mockMessages, { taskId: "test-task", tools: mockTools, }) - // Consume the stream to trigger the API call for await (const _chunk of stream) { - // Just consume + // consume } - expect(mockCreate).toHaveBeenCalledWith( + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - name: "get_weather", - description: "Get the current weather", - input_schema: expect.objectContaining({ - type: "object", - properties: expect.objectContaining({ - location: { type: "string" }, - }), - }), - }), - ]), + tools: expect.objectContaining({ + get_weather: expect.any(Object), + }), }), - expect.anything(), ) }) - it("should include tools when tools are provided", async () => { - const xmlHandler = new AnthropicHandler({ - ...mockOptions, + it("should handle tool call stream parts", async () => { + const mockFullStream = (async function* () { + yield { type: "tool-input-start", id: "toolu_123", toolName: "get_weather" } + yield { type: "tool-input-delta", id: "toolu_123", delta: '{"location":' } + yield { type: "tool-input-delta", id: "toolu_123", delta: '"London"}' } + yield { type: "tool-input-end", id: "toolu_123" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + reasoning: Promise.resolve([]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), }) - const stream = xmlHandler.createMessage(systemPrompt, messages, { - taskId: "test-task", - tools: mockTools, - }) + const stream = handler.createMessage(systemPrompt, mockMessages) + const chunks: any[] = [] - // Consume the stream to trigger the API call - for await (const _chunk of stream) { - // Just consume + for await (const chunk of stream) { + chunks.push(chunk) } - // Tool calling is request-driven: if tools are provided, we should include them. 
- expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - name: "get_weather", - }), - ]), - }), - expect.anything(), - ) - }) - - it("should always include tools in request (tools are always present after PR #10841)", async () => { - // Handler uses native protocol by default - const stream = handler.createMessage(systemPrompt, messages, { - taskId: "test-task", + const startChunk = chunks.find((c) => c.type === "tool_call_start") + expect(startChunk).toEqual({ + type: "tool_call_start", + id: "toolu_123", + name: "get_weather", }) - // Consume the stream to trigger the API call - for await (const _chunk of stream) { - // Just consume - } + const deltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + expect(deltaChunks).toHaveLength(2) + expect(deltaChunks[0]).toEqual({ + type: "tool_call_delta", + id: "toolu_123", + delta: '{"location":', + }) - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.any(Array), - tool_choice: expect.any(Object), - }), - expect.anything(), - ) + const endChunk = chunks.find((c) => c.type === "tool_call_end") + expect(endChunk).toEqual({ + type: "tool_call_end", + id: "toolu_123", + }) }) - it("should convert tool_choice 'auto' to Anthropic format", async () => { - // Handler uses native protocol by default - const stream = handler.createMessage(systemPrompt, messages, { - taskId: "test-task", - tools: mockTools, - tool_choice: "auto", + it("should reset reasoning state on each call", async () => { + // First call with signature + const mockFullStream1 = (async function* () { + yield { type: "text-delta", text: "First" } + })() + + mockStreamText.mockReturnValueOnce({ + fullStream: mockFullStream1, + reasoning: Promise.resolve([{ type: "text", text: "Think", signature: "sig1" }]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), }) - // Consume the stream to trigger the API call - for await (const _chunk of stream) { - // Just consume + const stream1 = handler.createMessage(systemPrompt, mockMessages) + for await (const _chunk of stream1) { + // consume } + expect(handler.getThoughtSignature()).toBe("sig1") + + // Second call without signature + const mockFullStream2 = (async function* () { + yield { type: "text-delta", text: "Second" } + })() + + mockStreamText.mockReturnValueOnce({ + fullStream: mockFullStream2, + reasoning: Promise.resolve([]), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tool_choice: { type: "auto", disable_parallel_tool_use: false }, - }), - expect.anything(), - ) + const stream2 = handler.createMessage(systemPrompt, mockMessages) + for await (const _chunk of stream2) { + // consume + } + expect(handler.getThoughtSignature()).toBeUndefined() + expect(handler.getRedactedThinkingBlocks()).toBeUndefined() }) + }) - it("should convert tool_choice 'required' to Anthropic 'any' format", async () => { - // Handler uses native protocol by default - const stream = handler.createMessage(systemPrompt, messages, { - taskId: "test-task", - tools: mockTools, - tool_choice: "required", + describe("completePrompt", () => { + it("should complete prompt successfully", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test response", }) - // Consume the stream 
to trigger the API call - for await (const _chunk of stream) { - // Just consume - } + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("Test response") - expect(mockCreate).toHaveBeenCalledWith( + expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - tool_choice: { type: "any", disable_parallel_tool_use: false }, + prompt: "Test prompt", + temperature: 0, }), - expect.anything(), ) }) - it("should set tool_choice to undefined when tool_choice is 'none' (tools are still passed)", async () => { - // Handler uses native protocol by default - const stream = handler.createMessage(systemPrompt, messages, { - taskId: "test-task", - tools: mockTools, - tool_choice: "none", - }) - - // Consume the stream to trigger the API call - for await (const _chunk of stream) { - // Just consume - } + it("should handle API errors", async () => { + const mockError = new Error("Anthropic completion error: API Error") + mockGenerateText.mockRejectedValue(mockError) + await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Anthropic completion error: API Error") + expect(mockCaptureException).toHaveBeenCalled() + }) - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - // When tool_choice is 'none', the converter returns undefined for tool_choice - // but tools are still passed since they're always present - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.any(Array), - tool_choice: undefined, - }), - expect.anything(), - ) + it("should handle empty response", async () => { + mockGenerateText.mockResolvedValue({ + text: "", + }) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") }) - it("should convert specific tool_choice to Anthropic 'tool' format", async () => { - // Handler uses native protocol by default - const stream = handler.createMessage(systemPrompt, messages, { - taskId: "test-task", - tools: mockTools, - tool_choice: { type: "function" as const, function: { name: "get_weather" } }, + it("should handle undefined text", async () => { + mockGenerateText.mockResolvedValue({ + text: undefined, }) + const result = await handler.completePrompt("Test prompt") + expect(result).toBe("") + }) + }) - // Consume the stream to trigger the API call - for await (const _chunk of stream) { - // Just consume - } + describe("getModel", () => { + it("should return default model if no model ID is provided", () => { + const handlerWithoutModel = new AnthropicHandler({ + ...mockOptions, + apiModelId: undefined, + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBeDefined() + expect(model.info).toBeDefined() + }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tool_choice: { type: "tool", name: "get_weather", disable_parallel_tool_use: false }, - }), - expect.anything(), - ) + it("should return specified model if valid model ID is provided", () => { + const model = handler.getModel() + expect(model.id).toBe(mockOptions.apiModelId) + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(8192) + expect(model.info.contextWindow).toBe(200_000) + expect(model.info.supportsImages).toBe(true) + expect(model.info.supportsPromptCache).toBe(true) }) - it("should enable parallel tool calls when parallelToolCalls is true", async () => { - // Handler uses native protocol by default - const stream = handler.createMessage(systemPrompt, messages, { - taskId: "test-task", - tools: mockTools, - tool_choice: "auto", - 
parallelToolCalls: true, + it("honors custom maxTokens for thinking models", () => { + const handler = new AnthropicHandler({ + apiKey: "test-api-key", + apiModelId: "claude-3-7-sonnet-20250219:thinking", + modelMaxTokens: 32_768, + modelMaxThinkingTokens: 16_384, }) - // Consume the stream to trigger the API call - for await (const _chunk of stream) { - // Just consume - } - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tool_choice: { type: "auto", disable_parallel_tool_use: false }, - }), - expect.anything(), - ) + const result = handler.getModel() + expect(result.maxTokens).toBe(32_768) + expect(result.reasoningBudget).toEqual(16_384) + expect(result.temperature).toBe(1.0) }) - it("should handle tool_use blocks in stream and emit tool_call_partial", async () => { - mockCreate.mockImplementationOnce(async () => ({ - async *[Symbol.asyncIterator]() { - yield { - type: "message_start", - message: { - usage: { - input_tokens: 100, - output_tokens: 50, - }, - }, - } - yield { - type: "content_block_start", - index: 0, - content_block: { - type: "tool_use", - id: "toolu_123", - name: "get_weather", - }, - } - }, - })) - - // Handler uses native protocol by default - const stream = handler.createMessage(systemPrompt, messages, { - taskId: "test-task", - tools: mockTools, + it("does not honor custom maxTokens for non-thinking models", () => { + const handler = new AnthropicHandler({ + apiKey: "test-api-key", + apiModelId: "claude-3-7-sonnet-20250219", + modelMaxTokens: 32_768, + modelMaxThinkingTokens: 16_384, }) - const chunks: any[] = [] - for await (const chunk of stream) { - chunks.push(chunk) - } + const result = handler.getModel() + expect(result.maxTokens).toBe(8192) + expect(result.reasoningBudget).toBeUndefined() + expect(result.temperature).toBe(0) + }) - // Find the tool_call_partial chunk - const toolCallChunk = chunks.find((chunk) => chunk.type === "tool_call_partial") - expect(toolCallChunk).toBeDefined() - expect(toolCallChunk).toEqual({ - type: "tool_call_partial", - index: 0, - id: "toolu_123", - name: "get_weather", - arguments: undefined, + it("should handle Claude 4.5 Sonnet model correctly", () => { + const handler = new AnthropicHandler({ + apiKey: "test-api-key", + apiModelId: "claude-sonnet-4-5", }) + const model = handler.getModel() + expect(model.id).toBe("claude-sonnet-4-5") + expect(model.info.maxTokens).toBe(64000) + expect(model.info.contextWindow).toBe(200000) + expect(model.info.supportsReasoningBudget).toBe(true) }) - it("should handle input_json_delta in stream and emit tool_call_partial arguments", async () => { - mockCreate.mockImplementationOnce(async () => ({ - async *[Symbol.asyncIterator]() { - yield { - type: "message_start", - message: { - usage: { - input_tokens: 100, - output_tokens: 50, - }, - }, - } - yield { - type: "content_block_start", - index: 0, - content_block: { - type: "tool_use", - id: "toolu_123", - name: "get_weather", - }, - } - yield { - type: "content_block_delta", - index: 0, - delta: { - type: "input_json_delta", - partial_json: '{"location":', - }, - } - yield { - type: "content_block_delta", - index: 0, - delta: { - type: "input_json_delta", - partial_json: '"London"}', - }, - } - yield { - type: "content_block_stop", - index: 0, - } - }, - })) - - // Handler uses native protocol by default - const stream = handler.createMessage(systemPrompt, messages, { - taskId: "test-task", - tools: mockTools, + it("should enable 1M context for Claude 4.5 Sonnet when beta flag is set", () => { + const handler = new 
AnthropicHandler({ + apiKey: "test-api-key", + apiModelId: "claude-sonnet-4-5", + anthropicBeta1MContext: true, }) + const model = handler.getModel() + expect(model.info.contextWindow).toBe(1000000) + expect(model.info.inputPrice).toBe(6.0) + expect(model.info.outputPrice).toBe(22.5) + }) - const chunks: any[] = [] - for await (const chunk of stream) { - chunks.push(chunk) - } - - // Find the tool_call_partial chunks - const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - expect(toolCallChunks).toHaveLength(3) - - // First chunk has id and name - expect(toolCallChunks[0]).toEqual({ - type: "tool_call_partial", - index: 0, - id: "toolu_123", - name: "get_weather", - arguments: undefined, + it("should strip :thinking suffix from model ID", () => { + const handler = new AnthropicHandler({ + apiKey: "test-api-key", + apiModelId: "claude-3-7-sonnet-20250219:thinking", }) + const model = handler.getModel() + expect(model.id).toBe("claude-3-7-sonnet-20250219") + expect(model.betas).toContain("output-128k-2025-02-19") + }) + }) - // Subsequent chunks have arguments - expect(toolCallChunks[1]).toEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '{"location":', - }) + describe("isAiSdkProvider", () => { + it("should return true", () => { + expect(handler.isAiSdkProvider()).toBe(true) + }) + }) - expect(toolCallChunks[2]).toEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"London"}', - }) + describe("getThoughtSignature", () => { + it("should return undefined before any call", () => { + expect(handler.getThoughtSignature()).toBeUndefined() + }) + }) + + describe("getRedactedThinkingBlocks", () => { + it("should return undefined before any call", () => { + expect(handler.getRedactedThinkingBlocks()).toBeUndefined() }) }) }) diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index fc6cc048c7e..ecadc0beaca 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -1,7 +1,6 @@ -import { Anthropic } from "@anthropic-ai/sdk" -import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming" -import { CacheControlEphemeral } from "@anthropic-ai/sdk/resources" -import OpenAI from "openai" +import type { Anthropic } from "@anthropic-ai/sdk" +import { createAnthropic, type AnthropicProvider } from "@ai-sdk/anthropic" +import { streamText, generateText, ToolSet } from "ai" import { type ModelInfo, @@ -15,34 +14,67 @@ import { TelemetryService } from "@roo-code/telemetry" import type { ApiHandlerOptions } from "../../shared/api" -import { ApiStream } from "../transform/stream" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, +} from "../transform/ai-sdk" +import type { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { getModelParams } from "../transform/model-params" -import { filterNonAnthropicBlocks } from "../transform/anthropic-filter" -import { handleProviderError } from "./utils/error-handler" +import { addAiSdkAnthropicCacheBreakpoints } from "../transform/caching/ai-sdk-anthropic" -import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { BaseProvider } from "./base-provider" +import { DEFAULT_HEADERS } from "./constants" import { calculateApiCostAnthropic } from "../../shared/cost" -import { - convertOpenAIToolsToAnthropic, - convertOpenAIToolChoiceToAnthropic, 
-} from "../../core/prompts/tools/native-tools/converters" + +/** + * Models that support Anthropic prompt caching. + * These models require the `prompt-caching-2024-07-31` beta header. + */ +const CACHE_SUPPORTED_MODELS = new Set([ + "claude-sonnet-4-5", + "claude-sonnet-4-20250514", + "claude-opus-4-6", + "claude-opus-4-5-20251101", + "claude-opus-4-1-20250805", + "claude-opus-4-20250514", + "claude-3-7-sonnet-20250219", + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", + "claude-haiku-4-5-20251001", + "claude-3-haiku-20240307", +]) + +/** + * Models that support the 1M context beta. + */ +const CONTEXT_1M_MODELS = new Set(["claude-sonnet-4-20250514", "claude-sonnet-4-5", "claude-opus-4-6"]) export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler { private options: ApiHandlerOptions - private client: Anthropic + private provider: AnthropicProvider private readonly providerName = "Anthropic" + private lastThoughtSignature: string | undefined + private lastRedactedBlocks: Array<{ type: "redacted_thinking"; data: string }> | undefined constructor(options: ApiHandlerOptions) { super() this.options = options - const apiKeyFieldName = - this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken ? "authToken" : "apiKey" + const useAuthToken = !!(this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken) - this.client = new Anthropic({ + const headers: Record<string, string> = { ...DEFAULT_HEADERS } + if (useAuthToken && this.options.apiKey) { + headers["Authorization"] = `Bearer ${this.options.apiKey}` + } + + this.provider = createAnthropic({ + apiKey: useAuthToken ? "" : (this.options.apiKey ?? "not-provided"), baseURL: this.options.anthropicBaseUrl || undefined, - [apiKeyFieldName]: this.options.apiKey, + headers, }) } @@ -51,9 +83,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - let stream: AnthropicStream<Anthropic.Messages.RawMessageStreamEvent> - const cacheControl: CacheControlEphemeral = { type: "ephemeral" } - let { + const { id: modelId, betas = ["fine-grained-tool-streaming-2025-05-14"], maxTokens, @@ -61,271 +91,120 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa reasoning: thinking, } = this.getModel() - // Filter out non-Anthropic blocks (reasoning, thoughtSignature, etc.) before sending to the API - const sanitizedMessages = filterNonAnthropicBlocks(messages) - - // Add 1M context beta flag if enabled for supported models (Claude Sonnet 4/4.5, Opus 4.6) - if ( - (modelId === "claude-sonnet-4-20250514" || - modelId === "claude-sonnet-4-5" || - modelId === "claude-opus-4-6") && - this.options.anthropicBeta1MContext - ) { - betas.push("context-1m-2025-08-07") + // Build beta headers + const betaHeaders = [...betas] + + if (CACHE_SUPPORTED_MODELS.has(modelId)) { + betaHeaders.push("prompt-caching-2024-07-31") } - const nativeToolParams = { - tools: convertOpenAIToolsToAnthropic(metadata?.tools ?? 
[]), - tool_choice: convertOpenAIToolChoiceToAnthropic(metadata?.tool_choice, metadata?.parallelToolCalls), + if (CONTEXT_1M_MODELS.has(modelId) && this.options.anthropicBeta1MContext) { + betaHeaders.push("context-1m-2025-08-07") } - switch (modelId) { - case "claude-sonnet-4-5": - case "claude-sonnet-4-20250514": - case "claude-opus-4-6": - case "claude-opus-4-5-20251101": - case "claude-opus-4-1-20250805": - case "claude-opus-4-20250514": - case "claude-3-7-sonnet-20250219": - case "claude-3-5-sonnet-20241022": - case "claude-3-5-haiku-20241022": - case "claude-3-opus-20240229": - case "claude-haiku-4-5-20251001": - case "claude-3-haiku-20240307": { - /** - * The latest message will be the new user message, one before - * will be the assistant message from a previous request, and - * the user message before that will be a previously cached user - * message. So we need to mark the latest user message as - * ephemeral to cache it for the next request, and mark the - * second to last user message as ephemeral to let the server - * know the last message to retrieve from the cache for the - * current request. - */ - const userMsgIndices = sanitizedMessages.reduce( - (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc), - [] as number[], - ) - - const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1 - const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 - - try { - stream = await this.client.messages.create( - { - model: modelId, - max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS, - temperature, - thinking, - // Setting cache breakpoint for system prompt so new tasks can reuse it. - system: [{ text: systemPrompt, type: "text", cache_control: cacheControl }], - messages: sanitizedMessages.map((message, index) => { - if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) { - return { - ...message, - content: - typeof message.content === "string" - ? [{ type: "text", text: message.content, cache_control: cacheControl }] - : message.content.map((content, contentIndex) => - contentIndex === message.content.length - 1 - ? { ...content, cache_control: cacheControl } - : content, - ), - } + // Convert messages to AI SDK format (handles filtering of reasoning/thinking/etc. blocks) + const aiSdkMessages = convertToAiSdkMessages(messages) + + // Add cache breakpoints to the last 2 user messages + const useCache = CACHE_SUPPORTED_MODELS.has(modelId) + const cachedMessages = useCache ? addAiSdkAnthropicCacheBreakpoints(aiSdkMessages) : aiSdkMessages + + // Convert tools to AI SDK format + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + const toolChoice = mapToolChoice(metadata?.tool_choice) + + // Map Anthropic thinking config from snake_case to camelCase for AI SDK + const thinkingProviderOptions = thinking + ? { + thinking: + thinking.type === "enabled" + ? 
{ + type: "enabled" as const, + budgetTokens: (thinking as { budget_tokens: number }).budget_tokens, } - return message - }), - stream: true, - ...nativeToolParams, - }, - (() => { - // prompt caching: https://x.com/alexalbert__/status/1823751995901272068 - // https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers - // https://github.com/anthropics/anthropic-sdk-typescript/commit/c920b77fc67bd839bfeb6716ceab9d7c9bbe7393 - - // Then check for models that support prompt caching - switch (modelId) { - case "claude-sonnet-4-5": - case "claude-sonnet-4-20250514": - case "claude-opus-4-6": - case "claude-opus-4-5-20251101": - case "claude-opus-4-1-20250805": - case "claude-opus-4-20250514": - case "claude-3-7-sonnet-20250219": - case "claude-3-5-sonnet-20241022": - case "claude-3-5-haiku-20241022": - case "claude-3-opus-20240229": - case "claude-haiku-4-5-20251001": - case "claude-3-haiku-20240307": - betas.push("prompt-caching-2024-07-31") - return { headers: { "anthropic-beta": betas.join(",") } } - default: - return undefined - } - })(), - ) - } catch (error) { - TelemetryService.instance.captureException( - new ApiProviderError( - error instanceof Error ? error.message : String(error), - this.providerName, - modelId, - "createMessage", - ), - ) - throw error - } - break - } - default: { - try { - stream = (await this.client.messages.create({ - model: modelId, - max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS, - temperature, - system: [{ text: systemPrompt, type: "text" }], - messages: sanitizedMessages, - stream: true, - ...nativeToolParams, - })) as any - } catch (error) { - TelemetryService.instance.captureException( - new ApiProviderError( - error instanceof Error ? error.message : String(error), - this.providerName, - modelId, - "createMessage", - ), - ) - throw error + : thinking, } - break - } + : undefined + + // Build system prompt — with cache control for supported models + // Cast to any to bypass strict typing: the AI SDK Anthropic provider accepts + // text parts with providerOptions at runtime for system prompt caching. + const system: any = useCache + ? [ + { + type: "text" as const, + text: systemPrompt, + providerOptions: { anthropic: { cacheControl: { type: "ephemeral" } } }, + }, + ] + : systemPrompt + + // Build the request options + const requestOptions: Parameters<typeof streamText>[0] = { + model: this.provider(modelId), + system, + messages: cachedMessages, + maxOutputTokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS, + temperature, + tools: aiSdkTools, + toolChoice, + headers: { "anthropic-beta": betaHeaders.join(",") }, + ...(thinkingProviderOptions && { + providerOptions: { anthropic: thinkingProviderOptions } as any, + }), } - let inputTokens = 0 - let outputTokens = 0 - let cacheWriteTokens = 0 - let cacheReadTokens = 0 - - for await (const chunk of stream) { - switch (chunk.type) { - case "message_start": { - // Tells us cache reads/writes/input/output. 
- const { - input_tokens = 0, - output_tokens = 0, - cache_creation_input_tokens, - cache_read_input_tokens, - } = chunk.message.usage - - yield { - type: "usage", - inputTokens: input_tokens, - outputTokens: output_tokens, - cacheWriteTokens: cache_creation_input_tokens || undefined, - cacheReadTokens: cache_read_input_tokens || undefined, - } + try { + // Reset reasoning state for this request + this.lastThoughtSignature = undefined + this.lastRedactedBlocks = undefined - inputTokens += input_tokens - outputTokens += output_tokens - cacheWriteTokens += cache_creation_input_tokens || 0 - cacheReadTokens += cache_read_input_tokens || 0 + const result = streamText(requestOptions) - break + // Process the full stream + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk } - case "message_delta": - // Tells us stop_reason, stop_sequence, and output tokens - // along the way and at the end of the message. - yield { - type: "usage", - inputTokens: 0, - outputTokens: chunk.usage.output_tokens || 0, - } + } - break - case "message_stop": - // No usage data, just an indicator that the message is done. - break - case "content_block_start": - switch (chunk.content_block.type) { - case "thinking": - // We may receive multiple text blocks, in which - // case just insert a line break between them. - if (chunk.index > 0) { - yield { type: "reasoning", text: "\n" } - } - - yield { type: "reasoning", text: chunk.content_block.thinking } - break - case "text": - // We may receive multiple text blocks, in which - // case just insert a line break between them. - if (chunk.index > 0) { - yield { type: "text", text: "\n" } - } - - yield { type: "text", text: chunk.content_block.text } - break - case "tool_use": { - // Emit initial tool call partial with id and name - yield { - type: "tool_call_partial", - index: chunk.index, - id: chunk.content_block.id, - name: chunk.content_block.name, - arguments: undefined, - } - break - } + // After stream completes, capture reasoning data for signatures and redacted thinking + const reasoning = await result.reasoning + if (reasoning && Array.isArray(reasoning)) { + for (const entry of reasoning) { + // The AI SDK types reasoning parts as { type: "reasoning" } but the + // Anthropic provider returns richer types at runtime including "text" + // (with signature) and "redacted" (with data). Use any cast. + const entryAny = entry as any + if (entryAny.type === "text" && entryAny.signature) { + this.lastThoughtSignature = entryAny.signature } - break - case "content_block_delta": - switch (chunk.delta.type) { - case "thinking_delta": - yield { type: "reasoning", text: chunk.delta.thinking } - break - case "text_delta": - yield { type: "text", text: chunk.delta.text } - break - case "input_json_delta": { - // Emit tool call partial chunks as arguments stream in - yield { - type: "tool_call_partial", - index: chunk.index, - id: undefined, - name: undefined, - arguments: chunk.delta.partial_json, - } - break + if (entryAny.type === "redacted" && entryAny.data) { + if (!this.lastRedactedBlocks) { + this.lastRedactedBlocks = [] } + this.lastRedactedBlocks.push({ + type: "redacted_thinking", + data: entryAny.data, + }) } - - break - case "content_block_stop": - // Block complete - no action needed for now. 
- // NativeToolCallParser handles tool call completion - // Note: Signature for multi-turn thinking would require using stream.finalMessage() - // after iteration completes, which requires restructuring the streaming approach. - break + } } - } - if (inputTokens > 0 || outputTokens > 0 || cacheWriteTokens > 0 || cacheReadTokens > 0) { - const { totalCost } = calculateApiCostAnthropic( - this.getModel().info, - inputTokens, - outputTokens, - cacheWriteTokens, - cacheReadTokens, - ) + // Yield usage metrics at the end + const usage = await result.usage + const providerMetadata = await result.providerMetadata - yield { - type: "usage", - inputTokens: 0, - outputTokens: 0, - totalCost, + if (usage) { + yield this.processUsageMetrics(usage, this.getModel().info, providerMetadata) } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + TelemetryService.instance.captureException( + new ApiProviderError(errorMessage, this.providerName, modelId, "createMessage"), + ) + throw error } } @@ -335,11 +214,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa let info: ModelInfo = anthropicModels[id] // If 1M context beta is enabled for supported models, update the model info - if ( - (id === "claude-sonnet-4-20250514" || id === "claude-sonnet-4-5" || id === "claude-opus-4-6") && - this.options.anthropicBeta1MContext - ) { - // Use the tier pricing for 1M context + if (CONTEXT_1M_MODELS.has(id) && this.options.anthropicBeta1MContext) { const tier = info.tiers?.[0] if (tier) { info = { @@ -373,31 +248,87 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa } async completePrompt(prompt: string) { - let { id: model, temperature } = this.getModel() + const { id: model, temperature } = this.getModel() - let message try { - message = await this.client.messages.create({ - model, - max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS, - thinking: undefined, + const result = await generateText({ + model: this.provider(model), + prompt, + maxOutputTokens: ANTHROPIC_DEFAULT_MAX_TOKENS, temperature, - messages: [{ role: "user", content: prompt }], - stream: false, }) + + return result.text ?? "" } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) TelemetryService.instance.captureException( - new ApiProviderError( - error instanceof Error ? error.message : String(error), - this.providerName, - model, - "completePrompt", - ), + new ApiProviderError(errorMessage, this.providerName, model, "completePrompt"), ) throw error } + } + + /** + * Process usage metrics from the AI SDK response. + * Handles Anthropic-specific cache tokens from providerMetadata. 
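+ * Cache reads are reported on usage.details.cachedInputTokens, while cache writes
+ * only surface through providerMetadata.anthropic.cacheCreationInputTokens.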
+ */ + private processUsageMetrics( + usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }, + info: ModelInfo, + providerMetadata?: Record<string, unknown>, + ): ApiStreamUsageChunk { + const inputTokens = usage.inputTokens || 0 + const outputTokens = usage.outputTokens || 0 + const cacheReadTokens = usage.details?.cachedInputTokens + + // Cache write tokens come from Anthropic-specific provider metadata + const anthropicMeta = providerMetadata?.anthropic as { cacheCreationInputTokens?: number } | undefined + const cacheWriteTokens = anthropicMeta?.cacheCreationInputTokens + + const { totalCost } = calculateApiCostAnthropic( + info, + inputTokens, + outputTokens, + cacheWriteTokens, + cacheReadTokens, + ) + + return { + type: "usage", + inputTokens, + outputTokens, + cacheWriteTokens, + cacheReadTokens, + totalCost, + } + } + + override isAiSdkProvider(): boolean { + return true + } + + /** + * Returns the thought signature captured from the last Anthropic response. + * Anthropic extended thinking returns a signature on thinking blocks + * that must be round-tripped for tool use continuations. + */ + getThoughtSignature(): string | undefined { + return this.lastThoughtSignature + } - const content = message.content.find(({ type }) => type === "text") - return content?.type === "text" ? content.text : "" + /** + * Returns redacted thinking blocks from the last Anthropic response. + * These blocks are returned when safety filters trigger on reasoning content + * and must be passed back verbatim for proper reasoning continuity. + */ + getRedactedThinkingBlocks(): Array<{ type: "redacted_thinking"; data: string }> | undefined { + return this.lastRedactedBlocks } } diff --git a/src/api/transform/__tests__/ai-sdk.spec.ts b/src/api/transform/__tests__/ai-sdk.spec.ts index 3c1ca6d87e5..f69f97c3a92 100644 --- a/src/api/transform/__tests__/ai-sdk.spec.ts +++ b/src/api/transform/__tests__/ai-sdk.spec.ts @@ -349,7 +349,7 @@ describe("AI SDK conversion utilities", () => { expect(result[0]).toEqual({ role: "assistant", content: [ - { type: "reasoning", text: "Deep thought" }, + { type: "reasoning", text: "Deep thought", signature: "sig" }, { type: "text", text: "OK" }, ], }) diff --git a/src/api/transform/ai-sdk.ts b/src/api/transform/ai-sdk.ts index 9b48ee57f79..f664332d024 100644 --- a/src/api/transform/ai-sdk.ts +++ b/src/api/transform/ai-sdk.ts @@ -126,7 +126,8 @@ export function convertToAiSdkMessages( } } else if (message.role === "assistant") { const textParts: string[] = [] - const reasoningParts: string[] = [] + const reasoningEntries: Array<{ text: string; signature?: string }> = [] + const redactedReasoningEntries: Array<{ data: string }> = [] const reasoningContent = (() => { const maybe = (message as unknown as { reasoning_content?: unknown }).reasoning_content return typeof maybe === "string" && maybe.length > 0 ? 
maybe : undefined @@ -188,7 +189,7 @@ export function convertToAiSdkMessages( const text = (part as unknown as { text?: string }).text if (typeof text === "string" && text.length > 0) { - reasoningParts.push(text) + reasoningEntries.push({ text }) } continue } @@ -196,16 +197,31 @@ export function convertToAiSdkMessages( if ((part as unknown as { type?: string }).type === "thinking") { if (reasoningContent) continue - const thinking = (part as unknown as { thinking?: string }).thinking + const partAny2 = part as unknown as { thinking?: string; signature?: string } + const thinking = partAny2.thinking if (typeof thinking === "string" && thinking.length > 0) { - reasoningParts.push(thinking) + reasoningEntries.push({ + text: thinking, + ...(partAny2.signature ? { signature: partAny2.signature } : {}), + }) + } + continue + } + + // Anthropic redacted_thinking blocks must be round-tripped verbatim. + // The AI SDK represents these as { type: "redacted-reasoning", data: "..." }. + if ((part as unknown as { type?: string }).type === "redacted_thinking") { + const data = (part as unknown as { data?: string }).data + if (typeof data === "string") { + redactedReasoningEntries.push({ data }) } continue } } const content: Array< - | { type: "reasoning"; text: string } + | { type: "reasoning"; text: string; signature?: string } + | { type: "redacted-reasoning"; data: string } | { type: "text"; text: string } | { type: "tool-call" @@ -218,8 +234,26 @@ export function convertToAiSdkMessages( if (reasoningContent) { content.push({ type: "reasoning", text: reasoningContent }) - } else if (reasoningParts.length > 0) { - content.push({ type: "reasoning", text: reasoningParts.join("") }) + } else { + // When any entry carries a signature (Anthropic extended thinking), + // keep entries separate so signatures are preserved for round-tripping. + const hasSignatures = reasoningEntries.some((e) => e.signature) + + if (hasSignatures) { + for (const entry of reasoningEntries) { + content.push({ + type: "reasoning", + text: entry.text, + ...(entry.signature ? 
{ signature: entry.signature } : {}),
+						})
+					}
+				} else if (reasoningEntries.length > 0) {
+					content.push({ type: "reasoning", text: reasoningEntries.map((e) => e.text).join("") })
+				}
+
+				for (const entry of redactedReasoningEntries) {
+					content.push({ type: "redacted-reasoning", data: entry.data })
+				}
+			}
 
 			if (textParts.length > 0) {
@@ -416,6 +450,16 @@ export function* processAiSdkStreamPart(part: ExtendedStreamPart): Generator {
diff --git a/src/api/transform/caching/__tests__/ai-sdk-anthropic.spec.ts b/src/api/transform/caching/__tests__/ai-sdk-anthropic.spec.ts
new file mode 100644
--- /dev/null
+++ b/src/api/transform/caching/__tests__/ai-sdk-anthropic.spec.ts
+import type { ModelMessage } from "ai"
+
+import { addAiSdkAnthropicCacheBreakpoints } from "../ai-sdk-anthropic"
+
+const CACHE_CONTROL = { anthropic: { cacheControl: { type: "ephemeral" } } }
+
+describe("addAiSdkAnthropicCacheBreakpoints", () => {
+	it("should return messages unchanged when there are no user messages", () => {
+		const messages: ModelMessage[] = [{ role: "assistant", content: [{ type: "text", text: "Hello" }] }]
+
+		const result = addAiSdkAnthropicCacheBreakpoints(messages)
+		expect(result).toEqual(messages)
+	})
+
+	it("should add cache breakpoint to a single user message with string content", () => {
+		const messages: ModelMessage[] = [
+			{ role: "user", content: "Hello" },
+			{ role: "assistant", content: [{ type: "text", text: "Hi" }] },
+		]
+
+		const result = addAiSdkAnthropicCacheBreakpoints(messages)
+
+		expect(result[0]).toEqual({
+			role: "user",
+			content: [{ type: "text", text: "Hello", providerOptions: CACHE_CONTROL }],
+		})
+	})
+
+	it("should add cache breakpoints to the last two user messages", () => {
+		const messages: ModelMessage[] = [
+			{ role: "user", content: "First" },
+			{ role: "assistant", content: [{ type: "text", text: "Response 1" }] },
+			{ role: "user", content: "Second" },
+			{ role: "assistant", content: [{ type: "text", text: "Response 2" }] },
+			{ role: "user", content: "Third" },
+		]
+
+		const result = addAiSdkAnthropicCacheBreakpoints(messages)
+
+		// First user message should NOT have cache control
+		expect(result[0]).toEqual({ role: "user", content: "First" })
+
+		// Second user message should have cache control
+		expect(result[2]).toEqual({
+			role: "user",
+			content: [{ type: "text", text: "Second", providerOptions: CACHE_CONTROL }],
+		})
+
+		// Third user message should have cache control
+		expect(result[4]).toEqual({
+			role: "user",
+			content: [{ type: "text", text: "Third", providerOptions: CACHE_CONTROL }],
+		})
+	})
+
+	it("should add cache breakpoint to the last text part of array content", () => {
+		const messages: ModelMessage[] = [
+			{
+				role: "user",
+				content: [
+					{ type: "text", text: "First part" },
+					{ type: "image", image: "data:image/png;base64,..." },
+					{ type: "text", text: "Last text part" },
+				],
+			},
+		]
+
+		const result = addAiSdkAnthropicCacheBreakpoints(messages)
+
+		expect((result[0] as any).content).toEqual([
+			{ type: "text", text: "First part" },
+			{ type: "image", image: "data:image/png;base64,..." },
+			{ type: "text", text: "Last text part", providerOptions: CACHE_CONTROL },
+		])
+	})
+
+	it("should add placeholder text part when no text parts exist in array content", () => {
+		const messages: ModelMessage[] = [
+			{
+				role: "user",
+				content: [{ type: "image", image: "data:image/png;base64,..." }],
+			},
+		]
+
+		const result = addAiSdkAnthropicCacheBreakpoints(messages)
+
+		expect((result[0] as any).content).toEqual([
+			{ type: "image", image: "data:image/png;base64,..." 
}, + { type: "text", text: "...", providerOptions: CACHE_CONTROL }, + ]) + }) + + it("should not mutate the original messages", () => { + const messages: ModelMessage[] = [ + { role: "user", content: "Hello" }, + { role: "assistant", content: [{ type: "text", text: "Hi" }] }, + ] + + const original = JSON.parse(JSON.stringify(messages)) + addAiSdkAnthropicCacheBreakpoints(messages) + + expect(messages).toEqual(original) + }) + + it("should handle both user messages when only two exist", () => { + const messages: ModelMessage[] = [ + { role: "user", content: "First" }, + { role: "assistant", content: [{ type: "text", text: "Response" }] }, + { role: "user", content: "Second" }, + ] + + const result = addAiSdkAnthropicCacheBreakpoints(messages) + + expect(result[0]).toEqual({ + role: "user", + content: [{ type: "text", text: "First", providerOptions: CACHE_CONTROL }], + }) + expect(result[2]).toEqual({ + role: "user", + content: [{ type: "text", text: "Second", providerOptions: CACHE_CONTROL }], + }) + }) + + it("should not modify assistant or tool messages", () => { + const assistantMsg: ModelMessage = { role: "assistant", content: [{ type: "text", text: "Response" }] } + const toolMsg: ModelMessage = { + role: "tool", + content: [ + { + type: "tool-result", + toolCallId: "call1", + toolName: "test", + output: { type: "text", value: "result" }, + }, + ], + } as ModelMessage + + const messages: ModelMessage[] = [ + { role: "user", content: "Hello" }, + assistantMsg, + toolMsg, + { role: "user", content: "Continue" }, + ] + + const result = addAiSdkAnthropicCacheBreakpoints(messages) + + expect(result[1]).toEqual(assistantMsg) + expect(result[2]).toEqual(toolMsg) + }) +}) diff --git a/src/api/transform/caching/ai-sdk-anthropic.ts b/src/api/transform/caching/ai-sdk-anthropic.ts new file mode 100644 index 00000000000..f8fdd823451 --- /dev/null +++ b/src/api/transform/caching/ai-sdk-anthropic.ts @@ -0,0 +1,90 @@ +import type { ModelMessage } from "ai" + +const ANTHROPIC_CACHE_CONTROL = { + anthropic: { cacheControl: { type: "ephemeral" } }, +} + +/** + * Add Anthropic cache breakpoints to AI SDK ModelMessage array. + * Adds `providerOptions.anthropic.cacheControl` to the last text content part + * of the last 2 user messages, enabling prompt caching for Anthropic models + * via the AI SDK. + * + * Note: System prompt caching is handled separately at the streamText call level + * by passing the system prompt as an array with providerOptions. + * + * @param messages - Array of AI SDK ModelMessage objects + * @returns New array with cache breakpoints added (does not mutate input) + */ +export function addAiSdkAnthropicCacheBreakpoints(messages: ModelMessage[]): ModelMessage[] { + const userMsgIndices = messages.reduce( + (acc, msg, index) => (msg.role === "user" ? 
[...acc, index] : acc,
+		[] as number[],
+	)
+
+	const targetIndices = new Set(userMsgIndices.slice(-2))
+
+	if (targetIndices.size === 0) {
+		return messages
+	}
+
+	return messages.map((message, index) => {
+		if (!targetIndices.has(index)) {
+			return message
+		}
+
+		if (typeof message.content === "string") {
+			return {
+				...message,
+				content: [
+					{
+						type: "text" as const,
+						text: message.content,
+						providerOptions: ANTHROPIC_CACHE_CONTROL,
+					},
+				],
+			} as ModelMessage
+		}
+
+		if (Array.isArray(message.content)) {
+			// Find the index of the last text part
+			let lastTextIndex = -1
+			for (let i = message.content.length - 1; i >= 0; i--) {
+				if ((message.content[i] as { type: string }).type === "text") {
+					lastTextIndex = i
+					break
+				}
+			}
+
+			if (lastTextIndex === -1) {
+				// No text part found — add a placeholder
+				return {
+					...message,
+					content: [
+						...message.content,
+						{
+							type: "text" as const,
+							text: "...",
+							providerOptions: ANTHROPIC_CACHE_CONTROL,
+						},
+					],
+				} as ModelMessage
+			}
+
+			return {
+				...message,
+				content: message.content.map((part, i) => {
+					if (i === lastTextIndex) {
+						return {
+							...(part as Record<string, unknown>),
+							providerOptions: ANTHROPIC_CACHE_CONTROL,
+						}
+					}
+					return part
+				}),
+			} as ModelMessage
+		}
+
+		return message
+	})
+}
diff --git a/src/package.json b/src/package.json
index 3e0201c6412..0a638c62535 100644
--- a/src/package.json
+++ b/src/package.json
@@ -450,6 +450,7 @@
 		"clean": "rimraf README.md CHANGELOG.md LICENSE dist logs mock .turbo"
 	},
 	"dependencies": {
+		"@ai-sdk/anthropic": "^3.0.37",
 		"@ai-sdk/cerebras": "^1.0.0",
 		"@ai-sdk/deepseek": "^2.0.14",
 		"@ai-sdk/fireworks": "^2.0.26",
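
A minimal end-to-end sketch of how the pieces in this diff compose, assuming the AI SDK v5 surface (createAnthropic, streamText, fullStream, and the usage/providerMetadata promises) and the module path introduced above; the model id, env var, and import path are illustrative placeholders, not part of the change:

import { createAnthropic } from "@ai-sdk/anthropic"
import { streamText, type ModelMessage } from "ai"

// Path relative to the repo root; adjust for the actual call site.
import { addAiSdkAnthropicCacheBreakpoints } from "./src/api/transform/caching/ai-sdk-anthropic"

const anthropic = createAnthropic({ apiKey: process.env.ANTHROPIC_API_KEY })

const messages: ModelMessage[] = [{ role: "user", content: "Summarize the design doc." }]

const result = streamText({
	// Placeholder model id; any Anthropic model id works here.
	model: anthropic("claude-sonnet-4-5"),
	// Cache breakpoints land on the last two user messages.
	messages: addAiSdkAnthropicCacheBreakpoints(messages),
})

for await (const part of result.fullStream) {
	// text / reasoning / tool-call stream parts arrive here
}

// Usage resolves once the stream finishes. Cache reads live on the usage
// object; Anthropic-only cache writes surface via providerMetadata, mirroring
// processUsageMetrics above.
const usage = await result.usage
const anthropicMeta = (await result.providerMetadata)?.anthropic as
	| { cacheCreationInputTokens?: number }
	| undefined

console.log(usage.inputTokens, usage.outputTokens, anthropicMeta?.cacheCreationInputTokens)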
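
And a small illustration of the extended-thinking round-trip that convertToAiSdkMessages now performs. The input data is hypothetical, but the shapes match the thinking/redacted_thinking branches added in src/api/transform/ai-sdk.ts:

// Hypothetical Anthropic assistant turn with extended thinking. The signature
// and the redacted_thinking data are opaque values a prior response returned.
const assistantTurn = {
	role: "assistant" as const,
	content: [
		{ type: "thinking", thinking: "Consider the edge cases...", signature: "sig_abc" },
		{ type: "redacted_thinking", data: "opaque-base64-blob" },
		{ type: "text", text: "Here is the plan." },
	],
}

// Because one entry carries a signature, convertToAiSdkMessages keeps the
// reasoning entries separate instead of joining them, so the assistant
// message converts to:
//   { type: "reasoning", text: "Consider the edge cases...", signature: "sig_abc" }
//   { type: "redacted-reasoning", data: "opaque-base64-blob" }
//   { type: "text", text: "Here is the plan." }
// which is exactly what Anthropic needs back for tool-use continuations.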