Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion packages/types/src/experiment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ import type { Keys, Equals, AssertEqual } from "./type-fu.js"
* ExperimentId
*/

export const experimentIds = ["preventFocusDisruption", "imageGeneration", "runSlashCommand", "customTools"] as const
export const experimentIds = [
"preventFocusDisruption",
"imageGeneration",
"runSlashCommand",
"customTools",
"modelRouting",
] as const

export const experimentIdsSchema = z.enum(experimentIds)

Expand All @@ -21,6 +27,7 @@ export const experimentsSchema = z.object({
imageGeneration: z.boolean().optional(),
runSlashCommand: z.boolean().optional(),
customTools: z.boolean().optional(),
modelRouting: z.boolean().optional(),
})

export type Experiments = z.infer<typeof experimentsSchema>
Expand Down
7 changes: 7 additions & 0 deletions packages/types/src/global-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ export const globalSettingsSchema = z.object({
* @default true
*/
showWorktreesInHomeScreen: z.boolean().optional(),

/**
* The model ID to use for "light" tasks when model routing is enabled.
* Must be a model available from the same provider as the primary model.
* Requires the "modelRouting" experiment to be enabled.
*/
modelRoutingLightModelId: z.string().optional(),
})

export type GlobalSettings = z.infer<typeof globalSettingsSchema>
Expand Down
1 change: 1 addition & 0 deletions packages/types/src/vscode-extension-host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ export type ExtensionState = Pick<
| "maxGitStatusFiles"
| "requestDelaySeconds"
| "showWorktreesInHomeScreen"
| "modelRoutingLightModelId"
> & {
version: string
clineMessages: ClineMessage[]
Expand Down
233 changes: 233 additions & 0 deletions src/core/task/ModelRouter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import type { ToolName, ProviderSettings, Experiments, ToolGroup } from "@roo-code/types"

import { modelIdKeysByProvider, isTypicalProvider } from "@roo-code/types"

import { TOOL_GROUPS, ALWAYS_AVAILABLE_TOOLS } from "../../shared/tools"
import { EXPERIMENT_IDS, experiments } from "../../shared/experiments"
import { buildApiHandler, type ApiHandler } from "../../api"

/**
 * Tool complexity tier used for model routing decisions.
 *
 * - "light": turns whose tools all belong to a group listed in
 *   LIGHT_TOOL_GROUPS (information-gathering tools; presumably read_file,
 *   list_files, search_files, codebase_search - the actual membership is
 *   defined by TOOL_GROUPS, not in this file)
 * - "standard": everything else (edit, command, browser, mcp, etc.)
 */
export type ModelTier = "light" | "standard"

/**
 * Set of tool groups considered "light" for routing purposes.
 * Turns that only use tools from these groups (or always-available tools)
 * are eligible for routing to a cheaper model.
 *
 * Currently only the "read" group qualifies.
 */
const LIGHT_TOOL_GROUPS: ReadonlySet<ToolGroup> = new Set<ToolGroup>(["read"])

/**
 * Inverts TOOL_GROUPS into a lookup keyed by tool name.
 *
 * For every group, its regular `tools` are registered first and its optional
 * `customTools` second; a tool name that appears more than once ends up mapped
 * to the last group that registered it.
 *
 * @returns A map from tool name to the name of the group that owns it.
 */
function buildToolToGroupMap(): Map<string, ToolGroup> {
	const groupByTool = new Map<string, ToolGroup>()

	Object.entries(TOOL_GROUPS).forEach(([name, config]) => {
		const owner = name as ToolGroup

		// Regular tools first, then custom tools (when the group declares any).
		const members = config.customTools ? [...config.tools, ...config.customTools] : [...config.tools]

		for (const toolName of members) {
			groupByTool.set(toolName, owner)
		}
	})

	return groupByTool
}

// Reverse lookup (tool name -> tool group), computed once when this module is loaded.
const TOOL_TO_GROUP = buildToolToGroupMap()

/**
 * Always-available tools as a Set for fast lookup.
 * These tools (ask_followup_question, attempt_completion, update_todo_list, etc.)
 * do not affect model tier classification.
 */
const ALWAYS_AVAILABLE_SET = new Set<string>(ALWAYS_AVAILABLE_TOOLS)

/**
 * Determines the tool group for a given tool name.
 *
 * Returns undefined in two distinct situations:
 * - the tool is always-available (tier-neutral), or
 * - the name does not appear in any group of TOOL_GROUPS.
 *
 * Callers that need to tell these cases apart must check
 * ALWAYS_AVAILABLE_SET themselves (as classifyToolUsage does).
 *
 * @param toolName - Name of the tool to look up
 * @returns The owning tool group, or undefined as described above
 */
export function getToolGroup(toolName: string): ToolGroup | undefined {
	if (ALWAYS_AVAILABLE_SET.has(toolName)) {
		return undefined // Tier-neutral
	}
	return TOOL_TO_GROUP.get(toolName)
}

/**
 * Classifies a set of tool names into a model tier.
 *
 * - If no tools were used → "standard" (pure reasoning, needs full model)
 * - If only always-available tools were used → "standard"
 * - If ALL remaining tools are in light groups → "light"
 * - If ANY tool is in a non-light group → "standard"
 * - If ANY tool is not recognised at all (neither always-available nor a
 *   member of a known group) → "standard". An unrecognised tool may have
 *   arbitrary side effects, so it must never make a turn eligible for the
 *   cheaper model.
 *
 * @param toolNames - Names of the tools used during one turn
 * @returns The tier the turn falls into
 */
export function classifyToolUsage(toolNames: ReadonlySet<string>): ModelTier {
	let sawLightTool = false

	for (const toolName of toolNames) {
		// Always-available tools are tier-neutral and never influence the result.
		if (ALWAYS_AVAILABLE_SET.has(toolName)) {
			continue
		}

		const group = TOOL_TO_GROUP.get(toolName)

		// Unknown tool or tool from a non-light group: fail safe to the primary model.
		if (group === undefined || !LIGHT_TOOL_GROUPS.has(group)) {
			return "standard"
		}

		sawLightTool = true
	}

	// Light only if at least one light tool was actually used.
	return sawLightTool ? "light" : "standard"
}

/**
 * ModelRouter provides heuristic-based model routing for cost optimization.
 *
 * It tracks which tools were used in each API turn and uses that information
 * to decide whether the next API call should use a lighter (cheaper) model
 * or the primary (more capable) model.
 *
 * ## Heuristic (v1)
 * - First turn: always use primary model
 * - If previous turn only used "read" group tools: use light model
 * - If previous turn used edit/command/browser/mcp tools: use primary model
 * - If previous turn had no tool calls (pure reasoning): use primary model
 *
 * ## Usage
 * ```typescript
 * const router = new ModelRouter()
 *
 * // Before each API call
 * if (router.shouldUseLightModel()) {
 *   // Use light model
 * }
 *
 * // During tool execution
 * router.recordToolUse("read_file")
 *
 * // After API turn completes
 * router.endTurn()
 * ```
 */
export class ModelRouter {
	/** Tools used in the current (ongoing) turn */
	private currentTurnTools: Set<string> = new Set()

	/** Classification of the previous (completed) turn */
	private previousTurnTier: ModelTier = "standard"

	/** Whether at least one turn has completed */
	private hasPreviousTurn = false

	/**
	 * Record that a tool was used in the current turn.
	 *
	 * @param toolName - Name of the tool that was invoked
	 */
	recordToolUse(toolName: ToolName): void {
		this.currentTurnTools.add(toolName)
	}

	/**
	 * Signal that the current turn has completed.
	 * Classifies the tools recorded so far, stores the result as the
	 * "previous turn" tier and starts a fresh, empty tool set.
	 */
	endTurn(): void {
		this.previousTurnTier = classifyToolUsage(this.currentTurnTools)
		this.currentTurnTools = new Set()
		this.hasPreviousTurn = true
	}

	/**
	 * Check whether the next API call should use the light model.
	 *
	 * Returns true only if:
	 * - At least one turn has completed (never on first turn)
	 * - The previous turn was classified as "light"
	 */
	shouldUseLightModel(): boolean {
		return this.hasPreviousTurn && this.previousTurnTier === "light"
	}

	/**
	 * Get the current tier classification (for debugging/logging).
	 * Reports "standard" until the first turn has completed.
	 */
	getCurrentTier(): ModelTier {
		return this.hasPreviousTurn ? this.previousTurnTier : "standard"
	}

	/**
	 * Reset the router state (e.g., when task is restarted).
	 */
	reset(): void {
		this.currentTurnTools = new Set()
		this.previousTurnTier = "standard"
		this.hasPreviousTurn = false
	}

	/**
	 * Check if model routing is enabled based on experiment settings and configuration.
	 *
	 * @param experimentsConfig - The experiments configuration
	 * @param lightModelId - The light model ID from settings
	 * @returns true if model routing is fully configured and enabled; false when
	 *          either argument is missing or the model ID is blank
	 */
	static isEnabled(experimentsConfig: Experiments | undefined, lightModelId: string | undefined): boolean {
		if (!experimentsConfig || !lightModelId || lightModelId.trim() === "") {
			return false
		}
		return experiments.isEnabled(experimentsConfig, EXPERIMENT_IDS.MODEL_ROUTING)
	}

	/**
	 * Build a ProviderSettings with the light model ID substituted in place of
	 * the primary model ID. The provider and all other settings remain the same.
	 *
	 * The model ID is trimmed before use so that this method agrees with
	 * isEnabled(), which already ignores surrounding whitespace when deciding
	 * whether an ID is set. A blank ID yields null instead of a configuration
	 * with an empty model.
	 *
	 * @param baseConfig - The primary provider settings
	 * @param lightModelId - The model ID to use for light tasks
	 * @returns A new ProviderSettings with the light model, or null if the model ID
	 *          is blank or the provider is not supported for model routing
	 */
	static buildLightModelConfig(baseConfig: ProviderSettings, lightModelId: string): ProviderSettings | null {
		const normalizedModelId = lightModelId.trim()
		if (normalizedModelId === "") {
			return null
		}

		const provider = baseConfig.apiProvider
		if (!provider || !isTypicalProvider(provider)) {
			return null
		}

		// Name of the settings field that holds the model ID for this provider.
		const modelIdKey = modelIdKeysByProvider[provider]
		if (!modelIdKey) {
			return null
		}

		return {
			...baseConfig,
			[modelIdKey]: normalizedModelId,
		}
	}

	/**
	 * Build an ApiHandler configured for the light model.
	 *
	 * @param baseConfig - The primary provider settings
	 * @param lightModelId - The model ID to use for light tasks
	 * @returns An ApiHandler for the light model, or null if routing is not possible
	 */
	static buildLightModelHandler(baseConfig: ProviderSettings, lightModelId: string): ApiHandler | null {
		const lightConfig = ModelRouter.buildLightModelConfig(baseConfig, lightModelId)
		if (!lightConfig) {
			return null
		}
		return buildApiHandler(lightConfig)
	}
}
44 changes: 44 additions & 0 deletions src/core/task/Task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ import { getTaskDirectoryPath } from "../../utils/storage"
import { formatResponse } from "../prompts/responses"
import { SYSTEM_PROMPT } from "../prompts/system"
import { buildNativeToolsArrayWithRestrictions } from "./build-tools"
import { ModelRouter } from "./ModelRouter"

// core modules
import { ToolRepetitionDetector } from "../tools/ToolRepetitionDetector"
Expand Down Expand Up @@ -297,6 +298,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
}

toolRepetitionDetector: ToolRepetitionDetector
modelRouter: ModelRouter
rooIgnoreController?: RooIgnoreController
rooProtectedController?: RooProtectedController
fileContextTracker: FileContextTracker
Expand Down Expand Up @@ -614,6 +616,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {

this.apiConfiguration = apiConfiguration
this.api = buildApiHandler(this.apiConfiguration)
this.modelRouter = new ModelRouter()
this.autoApprovalHandler = new AutoApprovalHandler()

this.urlContentFetcher = new UrlContentFetcher(provider.context)
Expand Down Expand Up @@ -2792,6 +2795,9 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
await this.saveClineMessages()
await this.providerRef.deref()?.postStateToWebviewWithoutTaskHistory()

// Model routing: declare outside try block so catch can restore it
let primaryApiHandler: typeof this.api | undefined

try {
let cacheWriteTokens = 0
let cacheReadTokens = 0
Expand Down Expand Up @@ -2894,6 +2900,27 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {

await this.diffViewProvider.reset()

// Model routing: temporarily swap to light model if heuristics say so
{
const routingState = await this.providerRef.deref()?.getState()
if (
ModelRouter.isEnabled(routingState?.experiments, routingState?.modelRoutingLightModelId) &&
this.modelRouter.shouldUseLightModel()
) {
const lightHandler = ModelRouter.buildLightModelHandler(
this.apiConfiguration,
routingState!.modelRoutingLightModelId!,
)
if (lightHandler) {
primaryApiHandler = this.api
this.api = lightHandler
console.log(
`[Task#${this.taskId}] Model routing: using light model "${routingState!.modelRoutingLightModelId}" for this turn`,
)
}
}
}

// Cache model info once per API request to avoid repeated calls during streaming
// This is especially important for tools and background usage collection
this.cachedStreamingModel = this.api.getModel()
Expand Down Expand Up @@ -3579,6 +3606,14 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
presentAssistantMessage(this)
}

// Model routing: always restore primary API handler after streaming completes,
// regardless of whether the turn had content or not. This prevents permanently
// losing the primary model on empty-response retries or error paths.
if (primaryApiHandler) {
this.api = primaryApiHandler
primaryApiHandler = undefined
}

if (hasTextContent || hasToolUses) {
// NOTE: This comment is here for future reference - this was a
// workaround for `userMessageContent` not getting set to true.
Expand All @@ -3598,6 +3633,9 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {

await pWaitFor(() => this.userMessageContentReady)

// Model routing: end current turn classification
this.modelRouter.endTurn()

// If the model did not tool use, then we need to tell it to
// either use a tool or attempt_completion.
const didToolUse = this.assistantMessageContent.some(
Expand Down Expand Up @@ -3738,6 +3776,11 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
// If we reach here without continuing, return false (will always be false for now)
return false
} catch (error) {
// Model routing: restore primary API handler on error paths
if (primaryApiHandler) {
this.api = primaryApiHandler
primaryApiHandler = undefined
}
// This should never happen since the only thing that can throw an
// error is the attemptApiRequest, which is wrapped in a try catch
// that sends an ask where if noButtonClicked, will clear current
Expand Down Expand Up @@ -4639,6 +4682,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
}

this.toolUsage[toolName].attempts++
this.modelRouter.recordToolUse(toolName)
}

public recordToolError(toolName: ToolName, error?: string) {
Expand Down
Loading