Skip to content

Commit 46c1b14

Browse files
authored
feat: RunCodemodTool (#557)
1 parent 72e9991 commit 46c1b14

File tree

4 files changed

+125
-1
lines changed

4 files changed

+125
-1
lines changed

src/codegen/extensions/langchain/tools.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from codegen.extensions.tools.link_annotation import add_links_to_message
2121
from codegen.extensions.tools.replacement_edit import replacement_edit
2222
from codegen.extensions.tools.reveal_symbol import reveal_symbol
23+
from codegen.extensions.tools.run_codemod import run_codemod
2324
from codegen.extensions.tools.search import search
2425
from codegen.extensions.tools.semantic_edit import semantic_edit
2526
from codegen.extensions.tools.semantic_search import semantic_search
@@ -709,7 +710,8 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
709710
RenameFileTool(codebase),
710711
ReplacementEditTool(codebase),
711712
RevealSymbolTool(codebase),
712-
RunBashCommandTool(), # Note: This tool doesn't need the codebase
713+
RunBashCommandTool(),
714+
RunCodemodTool(codebase),
713715
SearchTool(codebase),
714716
SemanticEditTool(codebase),
715717
SemanticSearchTool(codebase),
@@ -770,3 +772,38 @@ def _run(
770772
count=count,
771773
)
772774
return json.dumps(result, indent=2)
775+
776+
777+
class RunCodemodInput(BaseModel):
778+
"""Input for running a codemod."""
779+
780+
codemod_source: str = Field(
781+
...,
782+
description="""Source code of the codemod function. Must define a 'run(codebase: Codebase)' function that makes the desired changes.
783+
Example:
784+
```python
785+
def run(codebase: Codebase):
786+
for file in codebase.files:
787+
if file.filepath.endswith('.py'):
788+
content = file.content
789+
# Make changes to content
790+
file.edit(new_content)
791+
```
792+
""",
793+
)
794+
795+
796+
class RunCodemodTool(BaseTool):
797+
"""Tool for running custom codemod functions."""
798+
799+
name: ClassVar[str] = "run_codemod"
800+
description: ClassVar[str] = "Run a custom codemod function to make systematic changes across the codebase"
801+
args_schema: ClassVar[type[BaseModel]] = RunCodemodInput
802+
codebase: Codebase = Field(exclude=True)
803+
804+
def __init__(self, codebase: Codebase) -> None:
805+
super().__init__(codebase=codebase)
806+
807+
def _run(self, codemod_source: str) -> str:
808+
result = run_codemod(self.codebase, codemod_source)
809+
return json.dumps(result, indent=2)

src/codegen/extensions/tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .rename_file import rename_file
2020
from .replacement_edit import replacement_edit
2121
from .reveal_symbol import reveal_symbol
22+
from .run_codemod import run_codemod
2223
from .search import search
2324
from .semantic_edit import semantic_edit
2425
from .semantic_search import semantic_search
@@ -45,6 +46,7 @@
4546
"rename_file",
4647
"replacement_edit",
4748
"reveal_symbol",
49+
"run_codemod",
4850
# Search operations
4951
"search",
5052
# Edit operations
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Tool for running custom codemod functions on the codebase."""
2+
3+
import importlib.util
4+
import sys
5+
from pathlib import Path
6+
from tempfile import NamedTemporaryFile
7+
from typing import Any
8+
9+
from codegen import Codebase
10+
11+
12+
def run_codemod(codebase: Codebase, codemod_source: str) -> dict[str, Any]:
13+
"""Run a custom codemod function on the codebase.
14+
15+
The codemod_source should define a function like:
16+
```python
17+
def run(codebase: Codebase):
18+
# Make changes to the codebase
19+
...
20+
```
21+
22+
Args:
23+
codebase: The codebase to operate on
24+
codemod_source: Source code of the codemod function
25+
26+
Returns:
27+
Dict containing execution results and diffs
28+
29+
Raises:
30+
ValueError: If codemod source is invalid or execution fails
31+
"""
32+
# Create a temporary module to run the codemod
33+
with NamedTemporaryFile(suffix=".py", mode="w", delete=False) as temp_file:
34+
# Add imports and write the codemod source
35+
temp_file.write("from codegen import Codebase\n\n")
36+
temp_file.write(codemod_source)
37+
temp_file.flush()
38+
39+
try:
40+
# Import the temporary module
41+
spec = importlib.util.spec_from_file_location("codemod", temp_file.name)
42+
if not spec or not spec.loader:
43+
msg = "Failed to create module spec"
44+
raise ValueError(msg)
45+
46+
module = importlib.util.module_from_spec(spec)
47+
sys.modules["codemod"] = module
48+
spec.loader.exec_module(module)
49+
50+
# Verify run function exists
51+
if not hasattr(module, "run"):
52+
msg = "Codemod must define a 'run' function"
53+
raise ValueError(msg)
54+
55+
# Run the codemod
56+
module.run(codebase)
57+
codebase.commit()
58+
diff = codebase.get_diff()
59+
60+
return {
61+
"status": "success",
62+
"diff": diff,
63+
}
64+
65+
except Exception as e:
66+
msg = f"Codemod execution failed: {e!s}"
67+
raise ValueError(msg)
68+
69+
finally:
70+
# Clean up temporary file
71+
Path(temp_file.name).unlink()

tests/unit/codegen/extensions/test_tools.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
rename_file,
1515
replacement_edit,
1616
reveal_symbol,
17+
run_codemod,
1718
search,
1819
semantic_edit,
1920
semantic_search,
@@ -240,3 +241,16 @@ def test_replacement_edit(codebase):
240241
)
241242
assert result["status"] == "unchanged"
242243
assert "No matches found" in result["message"]
244+
245+
246+
def test_run_codemod(codebase):
247+
"""Test running custom codemods."""
248+
# Test adding type hints
249+
codemod_source = """
250+
def run(codebase: Codebase):
251+
for file in codebase.files:
252+
file.edit('# hello, world!' + file.content)
253+
"""
254+
result = run_codemod(codebase, codemod_source)
255+
assert result["status"] == "success"
256+
assert "+# hello, world" in result["diff"]

0 commit comments

Comments
 (0)