1717run_agent_modal = modal .Function .from_name (app_name = "swebench-agent-run" , name = "run_agent_modal" )
1818
1919
20- async def process_batch_modal (examples : list [SweBenchExample ], run_id : str , num_workers = 5 , min_workers = 1 , max_retries = 3 ):
20+ async def process_batch_modal (examples : list [SweBenchExample ], run_id : str , model : str , num_workers = 5 , min_workers = 1 , max_retries = 3 ):
2121 """Process a batch of examples concurrently using a queue system with incremental worker scaling.
2222
2323 Args:
@@ -110,7 +110,7 @@ async def is_rate_limit_error(error):
110110
111111 async def process_example (example , attempt , current_task ):
112112 try :
113- result = await run_agent_modal .remote .aio (example , run_id = run_id )
113+ result = await run_agent_modal .remote .aio (example , run_id = run_id , model = model )
114114
115115 if result is None :
116116 print (f"Warning: Null result for { example .instance_id } " )
@@ -222,7 +222,7 @@ async def worker():
222222 return [results .get (example .instance_id , {"instance_id" : example .instance_id , "status" : "missing" }) for example in examples ]
223223
224224
225- def process_batch_local (examples : list [SweBenchExample ], num_workers = 5 , codebases : dict [str , Codebase ] = {}, run_id : str | None = None ):
225+ def process_batch_local (examples : list [SweBenchExample ], model : str , num_workers = 5 , codebases : dict [str , Codebase ] = {}, run_id : str | None = None ):
226226 """Process a batch of examples synchronously.
227227
228228 Args:
@@ -242,9 +242,9 @@ def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebase
242242 try :
243243 # Run the agent locally instead of using modal
244244 if codebases and example .instance_id in codebases :
245- result = run_agent_on_entry (example , codebase = codebases [example .instance_id ], run_id = run_id )
245+ result = run_agent_on_entry (example , model = model , codebase = codebases [example .instance_id ], run_id = run_id )
246246 else :
247- result = run_agent_on_entry (example , run_id = run_id )
247+ result = run_agent_on_entry (example , model = model , run_id = run_id )
248248 results .append (result )
249249
250250 except Exception as e :
@@ -267,7 +267,15 @@ def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebase
267267
268268
269269async def run_eval (
270- use_existing_preds : str | None , dataset : str , length : int , instance_id : str | None = None , local : bool = False , codebases : dict [str , Codebase ] = {}, repo : str | None = None , num_workers : int = 5
270+ use_existing_preds : str | None ,
271+ dataset : str ,
272+ length : int ,
273+ instance_id : str | None = None ,
274+ local : bool = False ,
275+ codebases : dict [str , Codebase ] = {},
276+ repo : str | None = None ,
277+ num_workers : int = 5 ,
278+ model : str = "claude-3-7-sonnet-latest" ,
271279):
272280 run_id = use_existing_preds or str (uuid .uuid4 ())
273281 print (f"Run ID: { run_id } " )
@@ -294,9 +302,9 @@ async def run_eval(
294302
295303 # Process all examples in parallel batches
296304 if local :
297- results = process_batch_local (examples , codebases = codebases , run_id = run_id )
305+ results = process_batch_local (examples , model = model , codebases = codebases , run_id = run_id )
298306 else :
299- results = await process_batch_modal (examples , num_workers = num_workers , run_id = run_id )
307+ results = await process_batch_modal (examples , model = model , run_id = run_id , num_workers = num_workers )
300308
301309 # Save individual results
302310 for result in results :
@@ -355,9 +363,11 @@ async def run_eval(
355363@click .option (
356364 "--num-workers" , help = "The number of workers to use. This is the number of examples that will be processed concurrently. A large number may lead to rate limiting issues." , type = int , default = 5
357365)
# CLI entry point for the evaluation run: echoes the selected repo and model,
# then runs the async `run_eval` coroutine to completion with the parsed options.
# NOTE: documented with comments rather than a docstring on purpose, because
# Click would surface a function docstring as the command's --help text.
@click.option("--model", help="The model to use.", type=str, default="claude-3-7-sonnet-latest")
def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo, num_workers, model):
    print(f"Repo: { repo } ")
    print(f"Model: { model } ")
    # This command takes no `codebases` argument, so None is passed explicitly;
    # every other argument is forwarded unchanged from the command line.
    asyncio.run(
        run_eval(
            use_existing_preds=use_existing_preds,
            dataset=dataset,
            length=length,
            instance_id=instance_id,
            codebases=None,
            local=local,
            repo=repo,
            num_workers=num_workers,
            model=model,
        )
    )
361371
362372
363373if __name__ == "__main__" :
0 commit comments