From 6e090e7112a123e9a38aba296c79ac36238e544e Mon Sep 17 00:00:00 2001 From: Roo Code Date: Fri, 6 Feb 2026 20:26:06 +0000 Subject: [PATCH 1/2] feat: add heuristic-based model routing for cost optimization Adds an experimental feature that dynamically routes API calls to a lighter/cheaper model when the task is in an information-gathering phase. Addresses Issue #11269 - Choose Model Dynamically Based on Request. ## How it works - New experiment flag: modelRouting (disabled by default) - New setting: modelRoutingLightModelId - the cheaper model ID to use - ModelRouter tracks tool usage per API turn - If previous turn only used "read" group tools (read_file, list_files, search_files, codebase_search), the next API call uses the light model - Edit, command, browser, and MCP tools always use the primary model - First turn always uses the primary model ## Changes - packages/types/src/experiment.ts: Add modelRouting experiment ID - packages/types/src/global-settings.ts: Add modelRoutingLightModelId - src/shared/experiments.ts: Add MODEL_ROUTING config - src/core/task/ModelRouter.ts: New heuristic-based model router - src/core/task/Task.ts: Integrate ModelRouter into task lifecycle - src/core/task/__tests__/ModelRouter.spec.ts: 35 tests (all passing) --- packages/types/src/experiment.ts | 9 +- packages/types/src/global-settings.ts | 7 + packages/types/src/vscode-extension-host.ts | 1 + src/core/task/ModelRouter.ts | 233 +++++++++++++++++ src/core/task/Task.ts | 33 +++ src/core/task/__tests__/ModelRouter.spec.ts | 266 ++++++++++++++++++++ src/core/webview/ClineProvider.ts | 1 + src/shared/__tests__/experiments.spec.ts | 3 + src/shared/experiments.ts | 2 + 9 files changed, 554 insertions(+), 1 deletion(-) create mode 100644 src/core/task/ModelRouter.ts create mode 100644 src/core/task/__tests__/ModelRouter.spec.ts diff --git a/packages/types/src/experiment.ts b/packages/types/src/experiment.ts index d7eb0b03d6c..86ce1438fb8 100644 --- a/packages/types/src/experiment.ts +++ b/packages/types/src/experiment.ts @@ -6,7 +6,13 @@ import type { Keys, Equals, AssertEqual } from "./type-fu.js" * ExperimentId */ -export const experimentIds = ["preventFocusDisruption", "imageGeneration", "runSlashCommand", "customTools"] as const +export const experimentIds = [ + "preventFocusDisruption", + "imageGeneration", + "runSlashCommand", + "customTools", + "modelRouting", +] as const export const experimentIdsSchema = z.enum(experimentIds) @@ -21,6 +27,7 @@ export const experimentsSchema = z.object({ imageGeneration: z.boolean().optional(), runSlashCommand: z.boolean().optional(), customTools: z.boolean().optional(), + modelRouting: z.boolean().optional(), }) export type Experiments = z.infer diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts index 11b9fe148d1..b57a42bbd80 100644 --- a/packages/types/src/global-settings.ts +++ b/packages/types/src/global-settings.ts @@ -232,6 +232,13 @@ export const globalSettingsSchema = z.object({ * @default true */ showWorktreesInHomeScreen: z.boolean().optional(), + + /** + * The model ID to use for "light" tasks when model routing is enabled. + * Must be a model available from the same provider as the primary model. + * Requires the "modelRouting" experiment to be enabled. + */ + modelRoutingLightModelId: z.string().optional(), }) export type GlobalSettings = z.infer diff --git a/packages/types/src/vscode-extension-host.ts b/packages/types/src/vscode-extension-host.ts index fa2f04c0e5d..560c7fc4dcb 100644 --- a/packages/types/src/vscode-extension-host.ts +++ b/packages/types/src/vscode-extension-host.ts @@ -334,6 +334,7 @@ export type ExtensionState = Pick< | "maxGitStatusFiles" | "requestDelaySeconds" | "showWorktreesInHomeScreen" + | "modelRoutingLightModelId" > & { version: string clineMessages: ClineMessage[] diff --git a/src/core/task/ModelRouter.ts b/src/core/task/ModelRouter.ts new file mode 100644 index 00000000000..0ebc2fec382 --- /dev/null +++ b/src/core/task/ModelRouter.ts @@ -0,0 +1,233 @@ +import type { ToolName, ProviderSettings, Experiments, ToolGroup } from "@roo-code/types" + +import { modelIdKeysByProvider, isTypicalProvider } from "@roo-code/types" + +import { TOOL_GROUPS, ALWAYS_AVAILABLE_TOOLS } from "../../shared/tools" +import { EXPERIMENT_IDS, experiments } from "../../shared/experiments" +import { buildApiHandler, type ApiHandler } from "../../api" + +/** + * Tool complexity tier used for model routing decisions. + * + * - "light": Information-gathering tools (read_file, list_files, search_files, codebase_search) + * - "standard": All other tools (edit, command, browser, mcp, etc.) + */ +export type ModelTier = "light" | "standard" + +/** + * Set of tool groups considered "light" for routing purposes. + * Turns that only use tools from these groups (or always-available tools) + * are eligible for routing to a cheaper model. + */ +const LIGHT_TOOL_GROUPS: ReadonlySet = new Set(["read"]) + +/** + * Build a reverse map from tool name to tool group. + */ +function buildToolToGroupMap(): Map { + const map = new Map() + for (const [groupName, groupConfig] of Object.entries(TOOL_GROUPS)) { + for (const tool of groupConfig.tools) { + map.set(tool, groupName as ToolGroup) + } + if (groupConfig.customTools) { + for (const tool of groupConfig.customTools) { + map.set(tool, groupName as ToolGroup) + } + } + } + return map +} + +const TOOL_TO_GROUP = buildToolToGroupMap() + +/** + * Always-available tools as a Set for fast lookup. + * These tools (ask_followup_question, attempt_completion, update_todo_list, etc.) + * do not affect model tier classification. + */ +const ALWAYS_AVAILABLE_SET = new Set(ALWAYS_AVAILABLE_TOOLS) + +/** + * Determines the tool group for a given tool name. + * Returns undefined for always-available tools (which are tier-neutral). + */ +export function getToolGroup(toolName: string): ToolGroup | undefined { + if (ALWAYS_AVAILABLE_SET.has(toolName)) { + return undefined // Tier-neutral + } + return TOOL_TO_GROUP.get(toolName) +} + +/** + * Classifies a set of tool names into a model tier. + * + * - If no tools were used → "standard" (pure reasoning, needs full model) + * - If ALL tools are in light groups or always-available → "light" + * - If ANY tool is in a non-light group → "standard" + */ +export function classifyToolUsage(toolNames: ReadonlySet): ModelTier { + if (toolNames.size === 0) { + return "standard" + } + + let hasLightTool = false + + for (const toolName of toolNames) { + const group = getToolGroup(toolName) + + if (group === undefined) { + // Always-available tool, does not affect classification + continue + } + + if (LIGHT_TOOL_GROUPS.has(group)) { + hasLightTool = true + } else { + // Found a non-light tool, immediately classify as standard + return "standard" + } + } + + // If we only found light tools (and possibly always-available ones), classify as light + return hasLightTool ? "light" : "standard" +} + +/** + * ModelRouter provides heuristic-based model routing for cost optimization. + * + * It tracks which tools were used in each API turn and uses that information + * to decide whether the next API call should use a lighter (cheaper) model + * or the primary (more capable) model. + * + * ## Heuristic (v1) + * - First turn: always use primary model + * - If previous turn only used "read" group tools: use light model + * - If previous turn used edit/command/browser/mcp tools: use primary model + * - If previous turn had no tool calls (pure reasoning): use primary model + * + * ## Usage + * ```typescript + * const router = new ModelRouter() + * + * // Before each API call + * if (router.shouldUseLightModel()) { + * // Use light model + * } + * + * // During tool execution + * router.recordToolUse("read_file") + * + * // After API turn completes + * router.endTurn() + * ``` + */ +export class ModelRouter { + /** Tools used in the current (ongoing) turn */ + private currentTurnTools: Set = new Set() + + /** Classification of the previous (completed) turn */ + private previousTurnTier: ModelTier = "standard" + + /** Whether at least one turn has completed */ + private hasPreviousTurn = false + + /** + * Record that a tool was used in the current turn. + */ + recordToolUse(toolName: ToolName): void { + this.currentTurnTools.add(toolName) + } + + /** + * Signal that the current turn has completed. + * Moves current turn's tool usage to the "previous turn" classification. + */ + endTurn(): void { + this.previousTurnTier = classifyToolUsage(this.currentTurnTools) + this.currentTurnTools = new Set() + this.hasPreviousTurn = true + } + + /** + * Check whether the next API call should use the light model. + * + * Returns true only if: + * - At least one turn has completed (never on first turn) + * - The previous turn was classified as "light" + */ + shouldUseLightModel(): boolean { + return this.hasPreviousTurn && this.previousTurnTier === "light" + } + + /** + * Get the current tier classification (for debugging/logging). + */ + getCurrentTier(): ModelTier { + return this.hasPreviousTurn ? this.previousTurnTier : "standard" + } + + /** + * Reset the router state (e.g., when task is restarted). + */ + reset(): void { + this.currentTurnTools = new Set() + this.previousTurnTier = "standard" + this.hasPreviousTurn = false + } + + /** + * Check if model routing is enabled based on experiment settings and configuration. + * + * @param experimentsConfig - The experiments configuration + * @param lightModelId - The light model ID from settings + * @returns true if model routing is fully configured and enabled + */ + static isEnabled(experimentsConfig: Experiments | undefined, lightModelId: string | undefined): boolean { + if (!experimentsConfig || !lightModelId || lightModelId.trim() === "") { + return false + } + return experiments.isEnabled(experimentsConfig, EXPERIMENT_IDS.MODEL_ROUTING as any) + } + + /** + * Build a ProviderSettings with the light model ID substituted in place of + * the primary model ID. The provider and all other settings remain the same. + * + * @param baseConfig - The primary provider settings + * @param lightModelId - The model ID to use for light tasks + * @returns A new ProviderSettings with the light model, or null if the provider + * is not supported for model routing + */ + static buildLightModelConfig(baseConfig: ProviderSettings, lightModelId: string): ProviderSettings | null { + const provider = baseConfig.apiProvider + if (!provider || !isTypicalProvider(provider)) { + return null + } + + const modelIdKey = modelIdKeysByProvider[provider] + if (!modelIdKey) { + return null + } + + return { + ...baseConfig, + [modelIdKey]: lightModelId, + } + } + + /** + * Build an ApiHandler configured for the light model. + * + * @param baseConfig - The primary provider settings + * @param lightModelId - The model ID to use for light tasks + * @returns An ApiHandler for the light model, or null if routing is not possible + */ + static buildLightModelHandler(baseConfig: ProviderSettings, lightModelId: string): ApiHandler | null { + const lightConfig = ModelRouter.buildLightModelConfig(baseConfig, lightModelId) + if (!lightConfig) { + return null + } + return buildApiHandler(lightConfig) + } +} diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index d5e9aa0cfb6..22cd577868d 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -96,6 +96,7 @@ import { getTaskDirectoryPath } from "../../utils/storage" import { formatResponse } from "../prompts/responses" import { SYSTEM_PROMPT } from "../prompts/system" import { buildNativeToolsArrayWithRestrictions } from "./build-tools" +import { ModelRouter } from "./ModelRouter" // core modules import { ToolRepetitionDetector } from "../tools/ToolRepetitionDetector" @@ -297,6 +298,7 @@ export class Task extends EventEmitter implements TaskLike { } toolRepetitionDetector: ToolRepetitionDetector + modelRouter: ModelRouter rooIgnoreController?: RooIgnoreController rooProtectedController?: RooProtectedController fileContextTracker: FileContextTracker @@ -614,6 +616,7 @@ export class Task extends EventEmitter implements TaskLike { this.apiConfiguration = apiConfiguration this.api = buildApiHandler(this.apiConfiguration) + this.modelRouter = new ModelRouter() this.autoApprovalHandler = new AutoApprovalHandler() this.urlContentFetcher = new UrlContentFetcher(provider.context) @@ -2894,6 +2897,28 @@ export class Task extends EventEmitter implements TaskLike { await this.diffViewProvider.reset() + // Model routing: temporarily swap to light model if heuristics say so + let primaryApiHandler: typeof this.api | undefined + { + const routingState = await this.providerRef.deref()?.getState() + if ( + ModelRouter.isEnabled(routingState?.experiments, routingState?.modelRoutingLightModelId) && + this.modelRouter.shouldUseLightModel() + ) { + const lightHandler = ModelRouter.buildLightModelHandler( + this.apiConfiguration, + routingState!.modelRoutingLightModelId!, + ) + if (lightHandler) { + primaryApiHandler = this.api + this.api = lightHandler + console.log( + `[Task#${this.taskId}] Model routing: using light model "${routingState!.modelRoutingLightModelId}" for this turn`, + ) + } + } + } + // Cache model info once per API request to avoid repeated calls during streaming // This is especially important for tools and background usage collection this.cachedStreamingModel = this.api.getModel() @@ -3598,6 +3623,13 @@ export class Task extends EventEmitter implements TaskLike { await pWaitFor(() => this.userMessageContentReady) + // Model routing: end current turn and restore primary handler + this.modelRouter.endTurn() + if (primaryApiHandler) { + this.api = primaryApiHandler + primaryApiHandler = undefined + } + // If the model did not tool use, then we need to tell it to // either use a tool or attempt_completion. const didToolUse = this.assistantMessageContent.some( @@ -4639,6 +4671,7 @@ export class Task extends EventEmitter implements TaskLike { } this.toolUsage[toolName].attempts++ + this.modelRouter.recordToolUse(toolName) } public recordToolError(toolName: ToolName, error?: string) { diff --git a/src/core/task/__tests__/ModelRouter.spec.ts b/src/core/task/__tests__/ModelRouter.spec.ts new file mode 100644 index 00000000000..c8b43c5f394 --- /dev/null +++ b/src/core/task/__tests__/ModelRouter.spec.ts @@ -0,0 +1,266 @@ +import type { ToolName, ProviderSettings, Experiments } from "@roo-code/types" + +import { ModelRouter, classifyToolUsage, getToolGroup, type ModelTier } from "../ModelRouter" + +describe("getToolGroup", () => { + it("returns 'read' for read group tools", () => { + expect(getToolGroup("read_file")).toBe("read") + expect(getToolGroup("search_files")).toBe("read") + expect(getToolGroup("list_files")).toBe("read") + expect(getToolGroup("codebase_search")).toBe("read") + }) + + it("returns 'edit' for edit group tools", () => { + expect(getToolGroup("apply_diff")).toBe("edit") + expect(getToolGroup("write_to_file")).toBe("edit") + }) + + it("returns 'command' for command group tools", () => { + expect(getToolGroup("execute_command")).toBe("command") + expect(getToolGroup("read_command_output")).toBe("command") + }) + + it("returns 'browser' for browser group tools", () => { + expect(getToolGroup("browser_action")).toBe("browser") + }) + + it("returns 'mcp' for mcp group tools", () => { + expect(getToolGroup("use_mcp_tool")).toBe("mcp") + expect(getToolGroup("access_mcp_resource")).toBe("mcp") + }) + + it("returns undefined for always-available tools including mode tools (tier-neutral)", () => { + // switch_mode and new_task are in ALWAYS_AVAILABLE_TOOLS, so they are tier-neutral + expect(getToolGroup("switch_mode")).toBeUndefined() + expect(getToolGroup("new_task")).toBeUndefined() + expect(getToolGroup("ask_followup_question")).toBeUndefined() + expect(getToolGroup("attempt_completion")).toBeUndefined() + expect(getToolGroup("update_todo_list")).toBeUndefined() + expect(getToolGroup("run_slash_command")).toBeUndefined() + expect(getToolGroup("skill")).toBeUndefined() + }) +}) + +describe("classifyToolUsage", () => { + it('returns "standard" when no tools were used', () => { + expect(classifyToolUsage(new Set())).toBe("standard") + }) + + it('returns "light" when only read tools were used', () => { + expect(classifyToolUsage(new Set(["read_file"]))).toBe("light") + expect(classifyToolUsage(new Set(["read_file", "search_files"]))).toBe("light") + expect(classifyToolUsage(new Set(["read_file", "list_files", "codebase_search"]))).toBe("light") + }) + + it('returns "light" when read tools and always-available tools are used', () => { + expect(classifyToolUsage(new Set(["read_file", "ask_followup_question"]))).toBe("light") + expect(classifyToolUsage(new Set(["search_files", "update_todo_list"]))).toBe("light") + }) + + it('returns "standard" when only always-available tools are used (no read tools)', () => { + // Only always-available tools - no light tools present, so "standard" + expect(classifyToolUsage(new Set(["ask_followup_question"]))).toBe("standard") + expect(classifyToolUsage(new Set(["update_todo_list"]))).toBe("standard") + expect(classifyToolUsage(new Set(["attempt_completion"]))).toBe("standard") + }) + + it('returns "standard" when any edit tool is used', () => { + expect(classifyToolUsage(new Set(["read_file", "apply_diff"]))).toBe("standard") + expect(classifyToolUsage(new Set(["write_to_file"]))).toBe("standard") + }) + + it('returns "standard" when any command tool is used', () => { + expect(classifyToolUsage(new Set(["read_file", "execute_command"]))).toBe("standard") + }) + + it('returns "standard" when any browser tool is used', () => { + expect(classifyToolUsage(new Set(["read_file", "browser_action"]))).toBe("standard") + }) + + it('returns "standard" when any mcp tool is used', () => { + expect(classifyToolUsage(new Set(["use_mcp_tool"]))).toBe("standard") + }) +}) + +describe("ModelRouter", () => { + let router: ModelRouter + + beforeEach(() => { + router = new ModelRouter() + }) + + describe("shouldUseLightModel", () => { + it("returns false on first turn (no previous turn)", () => { + expect(router.shouldUseLightModel()).toBe(false) + }) + + it("returns false after first turn with no tools", () => { + router.endTurn() + expect(router.shouldUseLightModel()).toBe(false) + }) + + it("returns true after a turn with only read tools", () => { + router.recordToolUse("read_file" as ToolName) + router.recordToolUse("search_files" as ToolName) + router.endTurn() + expect(router.shouldUseLightModel()).toBe(true) + }) + + it("returns false after a turn with edit tools", () => { + router.recordToolUse("read_file" as ToolName) + router.recordToolUse("apply_diff" as ToolName) + router.endTurn() + expect(router.shouldUseLightModel()).toBe(false) + }) + + it("returns false after a turn with command tools", () => { + router.recordToolUse("execute_command" as ToolName) + router.endTurn() + expect(router.shouldUseLightModel()).toBe(false) + }) + + it("returns true when read tools + always-available tools used", () => { + router.recordToolUse("read_file" as ToolName) + router.recordToolUse("update_todo_list" as ToolName) + router.endTurn() + expect(router.shouldUseLightModel()).toBe(true) + }) + + it("tracks multiple turns correctly", () => { + // Turn 1: read only -> next should use light + router.recordToolUse("read_file" as ToolName) + router.endTurn() + expect(router.shouldUseLightModel()).toBe(true) + + // Turn 2: edit -> next should use standard + router.recordToolUse("write_to_file" as ToolName) + router.endTurn() + expect(router.shouldUseLightModel()).toBe(false) + + // Turn 3: read only again -> next should use light + router.recordToolUse("list_files" as ToolName) + router.endTurn() + expect(router.shouldUseLightModel()).toBe(true) + }) + }) + + describe("getCurrentTier", () => { + it('returns "standard" before any turn completes', () => { + expect(router.getCurrentTier()).toBe("standard") + }) + + it('returns "light" after a read-only turn', () => { + router.recordToolUse("read_file" as ToolName) + router.endTurn() + expect(router.getCurrentTier()).toBe("light") + }) + }) + + describe("reset", () => { + it("resets router state to initial", () => { + router.recordToolUse("read_file" as ToolName) + router.endTurn() + expect(router.shouldUseLightModel()).toBe(true) + + router.reset() + expect(router.shouldUseLightModel()).toBe(false) + expect(router.getCurrentTier()).toBe("standard") + }) + }) + + describe("isEnabled", () => { + it("returns false when experiments is undefined", () => { + expect(ModelRouter.isEnabled(undefined, "some-model")).toBe(false) + }) + + it("returns false when lightModelId is undefined", () => { + const experiments: Experiments = { modelRouting: true } + expect(ModelRouter.isEnabled(experiments, undefined)).toBe(false) + }) + + it("returns false when lightModelId is empty string", () => { + const experiments: Experiments = { modelRouting: true } + expect(ModelRouter.isEnabled(experiments, "")).toBe(false) + expect(ModelRouter.isEnabled(experiments, " ")).toBe(false) + }) + + it("returns false when experiment is not enabled", () => { + const experiments: Experiments = { modelRouting: false } + expect(ModelRouter.isEnabled(experiments, "some-model")).toBe(false) + }) + + it("returns true when experiment is enabled and lightModelId is set", () => { + const experiments: Experiments = { modelRouting: true } + expect(ModelRouter.isEnabled(experiments, "claude-3-haiku-20241022")).toBe(true) + }) + }) + + describe("buildLightModelConfig", () => { + it("returns null for unsupported provider types", () => { + const config: ProviderSettings = { + apiProvider: "fake-ai" as any, + } + expect(ModelRouter.buildLightModelConfig(config, "some-model")).toBeNull() + }) + + it("returns null when apiProvider is not set", () => { + const config: ProviderSettings = {} + expect(ModelRouter.buildLightModelConfig(config, "some-model")).toBeNull() + }) + + it("creates config with light model ID for anthropic provider", () => { + const config: ProviderSettings = { + apiProvider: "anthropic", + apiModelId: "claude-sonnet-4-20250514", + apiKey: "test-key", + } + const result = ModelRouter.buildLightModelConfig(config, "claude-3-haiku-20241022") + expect(result).not.toBeNull() + expect(result!.apiProvider).toBe("anthropic") + expect(result!.apiModelId).toBe("claude-3-haiku-20241022") + expect(result!.apiKey).toBe("test-key") + }) + + it("creates config with light model ID for openrouter provider", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterModelId: "anthropic/claude-sonnet-4-20250514", + openRouterApiKey: "test-key", + } + const result = ModelRouter.buildLightModelConfig(config, "anthropic/claude-3-haiku-20241022") + expect(result).not.toBeNull() + expect(result!.apiProvider).toBe("openrouter") + expect(result!.openRouterModelId).toBe("anthropic/claude-3-haiku-20241022") + expect(result!.openRouterApiKey).toBe("test-key") + }) + + it("creates config with light model ID for gemini provider", () => { + const config: ProviderSettings = { + apiProvider: "gemini", + apiModelId: "gemini-2.5-pro", + geminiApiKey: "test-key", + } + const result = ModelRouter.buildLightModelConfig(config, "gemini-2.0-flash") + expect(result).not.toBeNull() + expect(result!.apiProvider).toBe("gemini") + expect(result!.apiModelId).toBe("gemini-2.0-flash") + expect(result!.geminiApiKey).toBe("test-key") + }) + + it("preserves all other settings from base config", () => { + const config: ProviderSettings = { + apiProvider: "anthropic", + apiModelId: "claude-sonnet-4-20250514", + apiKey: "test-key", + modelTemperature: 0.5, + enableReasoningEffort: true, + reasoningEffort: "medium", + } + const result = ModelRouter.buildLightModelConfig(config, "claude-3-haiku-20241022") + expect(result).not.toBeNull() + expect(result!.modelTemperature).toBe(0.5) + expect(result!.enableReasoningEffort).toBe(true) + expect(result!.reasoningEffort).toBe("medium") + }) + }) +}) diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 84cc76825f7..3ec53c81ae4 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -2411,6 +2411,7 @@ export class ClineProvider customSupportPrompts: stateValues.customSupportPrompts ?? {}, enhancementApiConfigId: stateValues.enhancementApiConfigId, experiments: stateValues.experiments ?? experimentDefault, + modelRoutingLightModelId: stateValues.modelRoutingLightModelId, autoApprovalEnabled: stateValues.autoApprovalEnabled ?? false, customModes, maxOpenTabsContext: stateValues.maxOpenTabsContext ?? 20, diff --git a/src/shared/__tests__/experiments.spec.ts b/src/shared/__tests__/experiments.spec.ts index 92a7d7604ff..905a565a65a 100644 --- a/src/shared/__tests__/experiments.spec.ts +++ b/src/shared/__tests__/experiments.spec.ts @@ -21,6 +21,7 @@ describe("experiments", () => { imageGeneration: false, runSlashCommand: false, customTools: false, + modelRouting: false, } expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.PREVENT_FOCUS_DISRUPTION)).toBe(false) }) @@ -31,6 +32,7 @@ describe("experiments", () => { imageGeneration: false, runSlashCommand: false, customTools: false, + modelRouting: false, } expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.PREVENT_FOCUS_DISRUPTION)).toBe(true) }) @@ -41,6 +43,7 @@ describe("experiments", () => { imageGeneration: false, runSlashCommand: false, customTools: false, + modelRouting: false, } expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.PREVENT_FOCUS_DISRUPTION)).toBe(false) }) diff --git a/src/shared/experiments.ts b/src/shared/experiments.ts index e189f99e23d..6986cb80a3b 100644 --- a/src/shared/experiments.ts +++ b/src/shared/experiments.ts @@ -5,6 +5,7 @@ export const EXPERIMENT_IDS = { IMAGE_GENERATION: "imageGeneration", RUN_SLASH_COMMAND: "runSlashCommand", CUSTOM_TOOLS: "customTools", + MODEL_ROUTING: "modelRouting", } as const satisfies Record type _AssertExperimentIds = AssertEqual>> @@ -20,6 +21,7 @@ export const experimentConfigsMap: Record = { IMAGE_GENERATION: { enabled: false }, RUN_SLASH_COMMAND: { enabled: false }, CUSTOM_TOOLS: { enabled: false }, + MODEL_ROUTING: { enabled: false }, } export const experimentDefault = Object.fromEntries( From 37c8e2ae3f5ca0f89fbfc66e250e7049f03b0c4c Mon Sep 17 00:00:00 2001 From: Roo Code Date: Sun, 8 Feb 2026 00:57:36 +0000 Subject: [PATCH 2/2] feat: add model routing UI settings, fix api handler restoration bug - Add translation keys for MODEL_ROUTING in locales/en/settings.json - Create ModelRoutingSettings component with toggle and light model ID input - Wire up MODEL_ROUTING special-case in ExperimentalSettings.tsx - Add modelRoutingLightModelId state binding in SettingsView.tsx - Fix bug: restore this.api to primary handler on all paths (stream failure, empty-response retry, catch) not just the happy path - Fix nit: remove unnecessary as any cast in ModelRouter.isEnabled() --- src/core/task/ModelRouter.ts | 2 +- src/core/task/Task.ts | 23 ++++++--- .../settings/ExperimentalSettings.tsx | 23 +++++++++ .../settings/ModelRoutingSettings.tsx | 50 +++++++++++++++++++ .../src/components/settings/SettingsView.tsx | 14 ++++++ webview-ui/src/i18n/locales/en/settings.json | 6 +++ 6 files changed, 111 insertions(+), 7 deletions(-) create mode 100644 webview-ui/src/components/settings/ModelRoutingSettings.tsx diff --git a/src/core/task/ModelRouter.ts b/src/core/task/ModelRouter.ts index 0ebc2fec382..bc35a2761bf 100644 --- a/src/core/task/ModelRouter.ts +++ b/src/core/task/ModelRouter.ts @@ -187,7 +187,7 @@ export class ModelRouter { if (!experimentsConfig || !lightModelId || lightModelId.trim() === "") { return false } - return experiments.isEnabled(experimentsConfig, EXPERIMENT_IDS.MODEL_ROUTING as any) + return experiments.isEnabled(experimentsConfig, EXPERIMENT_IDS.MODEL_ROUTING) } /** diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 22cd577868d..d282034c48a 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -2795,6 +2795,9 @@ export class Task extends EventEmitter implements TaskLike { await this.saveClineMessages() await this.providerRef.deref()?.postStateToWebviewWithoutTaskHistory() + // Model routing: declare outside try block so catch can restore it + let primaryApiHandler: typeof this.api | undefined + try { let cacheWriteTokens = 0 let cacheReadTokens = 0 @@ -2898,7 +2901,6 @@ export class Task extends EventEmitter implements TaskLike { await this.diffViewProvider.reset() // Model routing: temporarily swap to light model if heuristics say so - let primaryApiHandler: typeof this.api | undefined { const routingState = await this.providerRef.deref()?.getState() if ( @@ -3604,6 +3606,14 @@ export class Task extends EventEmitter implements TaskLike { presentAssistantMessage(this) } + // Model routing: always restore primary API handler after streaming completes, + // regardless of whether the turn had content or not. This prevents permanently + // losing the primary model on empty-response retries or error paths. + if (primaryApiHandler) { + this.api = primaryApiHandler + primaryApiHandler = undefined + } + if (hasTextContent || hasToolUses) { // NOTE: This comment is here for future reference - this was a // workaround for `userMessageContent` not getting set to true. @@ -3623,12 +3633,8 @@ export class Task extends EventEmitter implements TaskLike { await pWaitFor(() => this.userMessageContentReady) - // Model routing: end current turn and restore primary handler + // Model routing: end current turn classification this.modelRouter.endTurn() - if (primaryApiHandler) { - this.api = primaryApiHandler - primaryApiHandler = undefined - } // If the model did not tool use, then we need to tell it to // either use a tool or attempt_completion. @@ -3770,6 +3776,11 @@ export class Task extends EventEmitter implements TaskLike { // If we reach here without continuing, return false (will always be false for now) return false } catch (error) { + // Model routing: restore primary API handler on error paths + if (primaryApiHandler) { + this.api = primaryApiHandler + primaryApiHandler = undefined + } // This should never happen since the only thing that can throw an // error is the attemptApiRequest, which is wrapped in a try catch // that sends an ask where if noButtonClicked, will clear current diff --git a/webview-ui/src/components/settings/ExperimentalSettings.tsx b/webview-ui/src/components/settings/ExperimentalSettings.tsx index 23786ce0b98..9ca2e24c9cf 100644 --- a/webview-ui/src/components/settings/ExperimentalSettings.tsx +++ b/webview-ui/src/components/settings/ExperimentalSettings.tsx @@ -14,6 +14,7 @@ import { SearchableSetting } from "./SearchableSetting" import { ExperimentalFeature } from "./ExperimentalFeature" import { ImageGenerationSettings } from "./ImageGenerationSettings" import { CustomToolsSettings } from "./CustomToolsSettings" +import { ModelRoutingSettings } from "./ModelRoutingSettings" type ExperimentalSettingsProps = HTMLAttributes & { experiments: Experiments @@ -26,6 +27,8 @@ type ExperimentalSettingsProps = HTMLAttributes & { setImageGenerationProvider?: (provider: ImageGenerationProvider) => void setOpenRouterImageApiKey?: (apiKey: string) => void setImageGenerationSelectedModel?: (model: string) => void + modelRoutingLightModelId?: string + setModelRoutingLightModelId?: (modelId: string) => void } export const ExperimentalSettings = ({ @@ -39,6 +42,8 @@ export const ExperimentalSettings = ({ setImageGenerationProvider, setOpenRouterImageApiKey, setImageGenerationSelectedModel, + modelRoutingLightModelId, + setModelRoutingLightModelId, className, ...props }: ExperimentalSettingsProps) => { @@ -83,6 +88,24 @@ export const ExperimentalSettings = ({ ) } + if (config[0] === "MODEL_ROUTING" && setModelRoutingLightModelId) { + return ( + + + setExperimentEnabled(EXPERIMENT_IDS.MODEL_ROUTING, enabled) + } + modelRoutingLightModelId={modelRoutingLightModelId} + setModelRoutingLightModelId={setModelRoutingLightModelId} + /> + + ) + } if (config[0] === "CUSTOM_TOOLS") { return ( void + modelRoutingLightModelId: string | undefined + setModelRoutingLightModelId: (modelId: string) => void +} + +export const ModelRoutingSettings = ({ + enabled, + onChange, + modelRoutingLightModelId, + setModelRoutingLightModelId, +}: ModelRoutingSettingsProps) => { + const { t } = useAppTranslation() + + return ( +
+
+
+ onChange(e.target.checked)}> + {t("settings:experimental.MODEL_ROUTING.name")} + +
+

+ {t("settings:experimental.MODEL_ROUTING.description")} +

+
+ + {enabled && ( +
+
+ + setModelRoutingLightModelId(e.target.value)} + className="w-full" + /> +
+
+ )} +
+ ) +} diff --git a/webview-ui/src/components/settings/SettingsView.tsx b/webview-ui/src/components/settings/SettingsView.tsx index 1876302b472..06470956835 100644 --- a/webview-ui/src/components/settings/SettingsView.tsx +++ b/webview-ui/src/components/settings/SettingsView.tsx @@ -207,6 +207,7 @@ const SettingsView = forwardRef(({ onDone, t imageGenerationProvider, openRouterImageApiKey, openRouterImageGenerationSelectedModel, + modelRoutingLightModelId, reasoningBlockCollapsed, enterBehavior, includeCurrentTime, @@ -338,6 +339,16 @@ const SettingsView = forwardRef(({ onDone, t }) }, []) + const setModelRoutingLightModelId = useCallback((modelId: string) => { + setCachedState((prevState) => { + if (prevState.modelRoutingLightModelId !== modelId) { + setChangeDetected(true) + } + + return { ...prevState, modelRoutingLightModelId: modelId } + }) + }, []) + const setCustomSupportPromptsField = useCallback((prompts: Record) => { setCachedState((prevState) => { const previousStr = JSON.stringify(prevState.customSupportPrompts) @@ -422,6 +433,7 @@ const SettingsView = forwardRef(({ onDone, t imageGenerationProvider, openRouterImageApiKey, openRouterImageGenerationSelectedModel, + modelRoutingLightModelId, experiments, customSupportPrompts, }, @@ -927,6 +939,8 @@ const SettingsView = forwardRef(({ onDone, t setImageGenerationProvider={setImageGenerationProvider} setOpenRouterImageApiKey={setOpenRouterImageApiKey} setImageGenerationSelectedModel={setImageGenerationSelectedModel} + modelRoutingLightModelId={modelRoutingLightModelId} + setModelRoutingLightModelId={setModelRoutingLightModelId} /> )} diff --git a/webview-ui/src/i18n/locales/en/settings.json b/webview-ui/src/i18n/locales/en/settings.json index 61dfaf42af5..becd7ff223f 100644 --- a/webview-ui/src/i18n/locales/en/settings.json +++ b/webview-ui/src/i18n/locales/en/settings.json @@ -941,6 +941,12 @@ "refreshSuccess": "Tools refreshed successfully", "refreshError": "Failed to refresh tools", "toolParameters": "Parameters" + }, + "MODEL_ROUTING": { + "name": "Enable model routing", + "description": "When enabled, Roo dynamically routes API calls to a lighter (cheaper) model during information-gathering phases. The light model must be from the same provider as your primary model.", + "lightModelIdLabel": "Light Model ID", + "lightModelIdPlaceholder": "e.g. claude-3-haiku-20241022" } }, "promptCaching": {