diff --git a/docs/cli.md b/docs/cli.md
index 3609c811..bea7012f 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -324,6 +324,19 @@ The JSON configuration file may hold the following values:
     "hnsw:construction_ef": 100
   }
   ```
+- `filetype_map`: `dict[str, list[str]]`, a dictionary where the keys are
+  [language names](https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages)
+  and the values are lists of [Python regex patterns](https://docs.python.org/3/library/re.html)
+  matched against file extensions (without the leading dot). This overrides automatic
+  language detection and lets you pin a tree-sitter parser for file types whose
+  language cannot be detected reliably (e.g. `.phtml` files, which mix PHP and HTML).
+  Example configuration:
+  ```json5
+  "filetype_map": {
+    "php": ["^phtml$"]
+  }
+  ```
+
 - `chunk_filters`: `dict[str, list[str]]`, a dictionary where the keys are
   [language name](https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages)
   and values are lists of [Python regex patterns](https://docs.python.org/3/library/re.html)
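The sketch below is an illustrative aside, not part of the patch: it shows how a `filetype_map` entry like the one documented above is consumed once the chunking change further down is applied. `TreeSitterChunker` and `Config` are the classes touched by this diff; the `.phtml` path and the `chunk_size` value are made up for the example.

```python
# Illustrative sketch only (not part of the patch); the file path is hypothetical.
from vectorcode.chunking import TreeSitterChunker
from vectorcode.cli_utils import Config

config = Config(
    chunk_size=2000,
    # Route "*.phtml" files to the tree-sitter PHP parser; the regex is applied
    # to the extension with the leading dot stripped, hence "phtml", not ".phtml".
    filetype_map={"php": ["^phtml$"]},
)
chunker = TreeSitterChunker(config)

# With a matching filetype_map entry the parser comes straight from the config,
# so the pygments-based filename lookup is skipped for these files.
for chunk in chunker.chunk("templates/page.phtml"):
    print(chunk)
```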
diff --git a/src/vectorcode/chunking.py b/src/vectorcode/chunking.py
index 906c68fb..0a62b3c8 100644
--- a/src/vectorcode/chunking.py
+++ b/src/vectorcode/chunking.py
@@ -8,7 +8,7 @@
 from typing import Generator, Optional
 
 from pygments.lexer import Lexer
-from pygments.lexers import guess_lexer_for_filename
+from pygments.lexers import get_lexer_for_filename
 from pygments.util import ClassNotFound
 from tree_sitter import Node, Point
 from tree_sitter_language_pack import get_parser
@@ -240,7 +240,7 @@ def __chunk_node(
     @cache
     def __guess_type(self, path: str, content: str) -> Optional[Lexer]:
         try:
-            return guess_lexer_for_filename(path, content)
+            return get_lexer_for_filename(path, content)
         except ClassNotFound:
             return None
 
@@ -279,6 +279,40 @@ def __load_file_lines(self, path: str) -> list[str]:
             lines = fin.readlines()
         return lines
 
+
+    def __get_parser_from_config(self, file_path: str):
+        """
+        Get parser based on filetype_map config.
+        """
+        filetype_map = self.config.filetype_map
+        if not filetype_map:
+            logger.debug("filetype_map is empty in config.")
+            return None
+
+        filename = os.path.basename(file_path)
+        extension = os.path.splitext(file_path)[1]
+        if extension.startswith('.'):
+            extension = extension[1:]
+        logger.debug(f"Checking filetype map for extension '{extension}' in {filename}")
+        for _language, patterns in filetype_map.items():
+            language = _language.lower()
+            for pattern in patterns:
+                try:
+                    if re.search(pattern, extension):
+                        logger.debug(f"'{filename}' extension matches pattern '{pattern}' for language '{language}'. Attempting to load parser.")
+                        parser = get_parser(language)
+                        logger.debug(f"Found parser for language '{language}' from config.")
+                        return parser
+                except re.error as e:
+                    e.add_note(f"\nInvalid regex pattern '{pattern}' for language '{language}' in filetype_map")
+                    raise
+                except LookupError as e:
+                    e.add_note(f"\nTreeSitter parser for language '{language}' not found. Please check your filetype_map config.")
+                    raise
+
+        logger.debug(f"No matching filetype map entry found for {filename}.")
+        return None
+
     def chunk(self, data: str) -> Generator[Chunk, None, None]:
         """
         data: path to the file
@@ -294,21 +328,23 @@ def chunk(self, data: str) -> Generator[Chunk, None, None]:
             return
         parser = None
         language = None
-        lexer = self.__guess_type(data, content)
-        if lexer is not None:
-            lang_names = [lexer.name]
-            lang_names.extend(lexer.aliases)
-            for name in lang_names:
-                try:
-                    parser = get_parser(name.lower())
-                    if parser is not None:
-                        language = name.lower()
-                        logger.debug(
-                            "Detected %s filetype for treesitter chunking.", language
-                        )
-                        break
-                except LookupError:  # pragma: nocover
-                    pass
+        parser = self.__get_parser_from_config(data)
+        if parser is None:
+            lexer = self.__guess_type(data, content)
+            if lexer is not None:
+                lang_names = [lexer.name]
+                lang_names.extend(lexer.aliases)
+                for name in lang_names:
+                    try:
+                        parser = get_parser(name.lower())
+                        if parser is not None:
+                            language = name.lower()
+                            logger.debug(
+                                "Detected %s filetype for treesitter chunking.", language
+                            )
+                            break
+                    except LookupError:  # pragma: nocover
+                        pass
 
         if parser is None:
             logger.debug(
diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py
index d33f3752..36db2703 100644
--- a/src/vectorcode/cli_utils.py
+++ b/src/vectorcode/cli_utils.py
@@ -94,6 +94,7 @@ class Config:
     )
     hnsw: dict[str, str | int] = field(default_factory=dict)
     chunk_filters: dict[str, list[str]] = field(default_factory=dict)
+    filetype_map: dict[str, list[str]] = field(default_factory=dict)
     encoding: str = "utf8"
     hooks: bool = False
 
@@ -156,6 +157,9 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
                 "chunk_filters": config_dict.get(
                     "chunk_filters", default_config.chunk_filters
                 ),
+                "filetype_map": config_dict.get(
+                    "filetype_map", default_config.filetype_map
+                ),
                 "encoding": config_dict.get("encoding", default_config.encoding),
             }
         )
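As a complement to the cli_utils.py hunk above, here is a minimal sketch (again, not part of the patch) of the new key flowing through `Config.import_from`. It assumes, as the `.get(..., default)` pattern above suggests, that a config dict containing only `filetype_map` is accepted, with every other field falling back to its default.

```python
# Minimal sketch, not part of the patch: filetype_map is read from the parsed
# config dict exactly like chunk_filters, defaulting to an empty dict.
import asyncio

from vectorcode.cli_utils import Config

config_dict = {
    # Stand-in for the contents of a JSON/JSON5 config file.
    "filetype_map": {"php": ["^phtml$"]},
}

config = asyncio.run(Config.import_from(config_dict))
print(config.filetype_map)   # {'php': ['^phtml$']}
print(config.chunk_filters)  # untouched default: {}
```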
diff --git a/tests/test_chunking.py b/tests/test_chunking.py
index d7d5bda2..5154696a 100644
--- a/tests/test_chunking.py
+++ b/tests/test_chunking.py
@@ -223,6 +223,121 @@ def bar():
     assert chunks == ['def 测试():\n    return "foo"', 'def bar():\n    return "bar"']
     os.remove(test_file)
 
+def test_treesitter_chunker_javascript():
+    """Test TreeSitterChunker with a sample javascript file using tempfile."""
+    chunker = TreeSitterChunker(Config(chunk_size=60))
+
+    test_content = r"""
+function foo() {
+    return "foo";
+}
+
+function bar() {
+    return "bar";
+}
+    """
+
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".js") as tmp_file:
+        tmp_file.write(test_content)
+        test_file = tmp_file.name
+
+    chunks = list(str(i) for i in chunker.chunk(test_file))
+    assert chunks == ['function foo() {\n    return "foo";\n}', 'function bar() {\n    return "bar";\n}']
+    os.remove(test_file)
+
+def test_treesitter_chunker_javascript_genshi():
+    """Test TreeSitterChunker with a sample javascript + genshi file using tempfile (bypassing lexers via the filetype_map config param)."""
+    chunker = TreeSitterChunker(Config(chunk_size=60, filetype_map={"javascript": ["^kid$"]}))
+
+    test_content = r"""
+function foo() {
+    return `foo with ${genshi}`;
+}
+
+function bar() {
+    return "bar";
+}
+    """
+
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".kid") as tmp_file:
+        tmp_file.write(test_content)
+        test_file = tmp_file.name
+
+    chunks = list(str(i) for i in chunker.chunk(test_file))
+    assert chunks == ['function foo() {\n    return `foo with ${genshi}`;\n}', 'function bar() {\n    return "bar";\n}']
+    os.remove(test_file)
+
+def test_treesitter_chunker_parser_from_config_no_parser_found_error():
+    """Test TreeSitterChunker filetype_map: should raise an error if no parser is found"""
+    chunker = TreeSitterChunker(Config(chunk_size=60, filetype_map={"unknown_parser": ["^kid$"]}))
+
+    test_content = r"""
+function foo() {
+    return `foo with ${genshi}`;
+}
+
+function bar() {
+    return "bar";
+}
+    """
+
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".kid") as tmp_file:
+        tmp_file.write(test_content)
+        test_file = tmp_file.name
+
+
+    with pytest.raises(LookupError):
+        chunks = list(str(i) for i in chunker.chunk(test_file))
+        assert chunks == []
+    os.remove(test_file)
+
+def test_treesitter_chunker_parser_from_config_regex_error():
+    """Test TreeSitterChunker filetype_map: should raise an error if a regex is invalid"""
+    chunker = TreeSitterChunker(Config(chunk_size=60, filetype_map={"javascript": ["\\"]}))
+
+    test_content = r"""
+function foo() {
+    return `foo with ${genshi}`;
+}
+
+function bar() {
+    return "bar";
+}
+    """
+
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".kid") as tmp_file:
+        tmp_file.write(test_content)
+        test_file = tmp_file.name
+
+
+    with pytest.raises(Exception):
+        chunks = list(str(i) for i in chunker.chunk(test_file))
+        assert chunks == []
+    os.remove(test_file)
+
+def test_treesitter_chunker_parser_from_config_no_language_match():
+    """Test TreeSitterChunker filetype_map: should continue with the lexer parser checks if no language matches a regex"""
+    chunker = TreeSitterChunker(Config(chunk_size=60, filetype_map={"php": ["^jsx$"]}))
+
+    test_content = r"""
+function foo() {
+    return "foo";
+}
+
+function bar() {
+    return "bar";
+}
+    """
+
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".js") as tmp_file:
+        tmp_file.write(test_content)
+        test_file = tmp_file.name
+
+    chunks = list(str(i) for i in chunker.chunk(test_file))
+    assert chunks == ['function foo() {\n    return "foo";\n}', 'function bar() {\n    return "bar";\n}']
+    os.remove(test_file)
+
+
 def test_treesitter_chunker_filter():
     chunker = TreeSitterChunker(