Skip to content
Merged
13 changes: 13 additions & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,19 @@ The JSON configuration file may hold the following values:
"hnsw:construction_ef": 100
}
```
- `filetype_map`: `dict[str, list[str]]`, a dictionary where keys are
[language name](https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages)
and values are lists of [Python regex patterns](https://docs.python.org/3/library/re.html)
that will match file extensions. This allows overriding automatic language
detection by specifying a tree-sitter parser for file types whose language
cannot be correctly identified (e.g., `.phtml` files containing both PHP and HTML).
Example configuration:
```json5
"filetype_map": {
"php": ["^phtml$"]
}
```

- `chunk_filters`: `dict[str, list[str]]`, a dictionary where the keys are
[language name](https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages)
and values are lists of [Python regex patterns](https://docs.python.org/3/library/re.html)
Expand Down
70 changes: 53 additions & 17 deletions src/vectorcode/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Generator, Optional

from pygments.lexer import Lexer
from pygments.lexers import guess_lexer_for_filename
from pygments.lexers import get_lexer_for_filename
from pygments.util import ClassNotFound
from tree_sitter import Node, Point
from tree_sitter_language_pack import get_parser
Expand Down Expand Up @@ -240,7 +240,7 @@ def __chunk_node(
@cache
def __guess_type(self, path: str, content: str) -> Optional[Lexer]:
    """Resolve a Pygments lexer for *path* by filename pattern.

    Returns:
        The lexer registered for the filename, or ``None`` when Pygments
        has no lexer for this filename pattern.
    """
    # NOTE(review): @cache on an instance method keys on `self` and keeps
    # instances (and the full file content) alive for the cache's lifetime
    # (ruff B019) — consider a per-instance cache; preserved as-is here.
    try:
        # Filename-based lookup; `content` is forwarded as the optional
        # `code` argument but does not influence which lexer is chosen.
        return get_lexer_for_filename(path, content)
    except ClassNotFound:
        return None
Expand Down Expand Up @@ -279,6 +279,40 @@ def __load_file_lines(self, path: str) -> list[str]:
lines = fin.readlines()
return lines


def __get_parser_from_config(self, file_path: str):
    """Resolve a tree-sitter parser for *file_path* from the ``filetype_map``
    config option.

    Each key of ``filetype_map`` is a tree-sitter language name; each value
    is a list of regex patterns matched against the file's extension
    (without the leading dot).  The first matching pattern wins.

    Returns:
        The parser for the first matching language, or ``None`` when the
        map is empty or nothing matches.

    Raises:
        re.error: if a configured pattern is not a valid regex
            (re-raised with a note naming the offending pattern).
        LookupError: if a matching language has no tree-sitter parser
            (re-raised with a note pointing at the config).
    """
    filetype_map = self.config.filetype_map
    if not filetype_map:
        logger.debug("filetype_map is empty in config.")
        return None

    filename = os.path.basename(file_path)
    # Extension without the leading dot, e.g. "phtml" for "page.phtml".
    extension = os.path.splitext(file_path)[1]
    if extension.startswith("."):
        extension = extension[1:]
    # Lazy %-style args so formatting is skipped when DEBUG is off.
    logger.debug(
        "Checking filetype map for extension '%s' in %s", extension, filename
    )
    for _language, patterns in filetype_map.items():
        language = _language.lower()
        for pattern in patterns:
            try:
                if re.search(pattern, extension):
                    logger.debug(
                        "'%s' extension matches pattern '%s' for language '%s'."
                        " Attempting to load parser.",
                        filename,
                        pattern,
                        language,
                    )
                    parser = get_parser(language)
                    logger.debug(
                        "Found parser for language '%s' from config.", language
                    )
                    return parser
            except re.error as e:
                e.add_note(
                    f"\nInvalid regex pattern '{pattern}' for language '{language}' in filetype_map"
                )
                raise
            except LookupError as e:
                e.add_note(
                    f"\nTreeSitter Parser for language '{language}' not found. Please check your filetype_map config."
                )
                raise

    logger.debug("No matching filetype map entry found for %s.", filename)
    return None

def chunk(self, data: str) -> Generator[Chunk, None, None]:
"""
data: path to the file
Expand All @@ -294,21 +328,23 @@ def chunk(self, data: str) -> Generator[Chunk, None, None]:
return
parser = None
language = None
lexer = self.__guess_type(data, content)
if lexer is not None:
lang_names = [lexer.name]
lang_names.extend(lexer.aliases)
for name in lang_names:
try:
parser = get_parser(name.lower())
if parser is not None:
language = name.lower()
logger.debug(
"Detected %s filetype for treesitter chunking.", language
)
break
except LookupError: # pragma: nocover
pass
parser = self.__get_parser_from_config(data)
if parser is None:
lexer = self.__guess_type(data, content)
if lexer is not None:
lang_names = [lexer.name]
lang_names.extend(lexer.aliases)
for name in lang_names:
try:
parser = get_parser(name.lower())
if parser is not None:
language = name.lower()
logger.debug(
"Detected %s filetype for treesitter chunking.", language
)
break
except LookupError: # pragma: nocover
pass

if parser is None:
logger.debug(
Expand Down
4 changes: 4 additions & 0 deletions src/vectorcode/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class Config:
)
hnsw: dict[str, str | int] = field(default_factory=dict)
chunk_filters: dict[str, list[str]] = field(default_factory=dict)
filetype_map: dict[str, list[str]] = field(default_factory=dict)
encoding: str = "utf8"
hooks: bool = False

Expand Down Expand Up @@ -156,6 +157,9 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
"chunk_filters": config_dict.get(
"chunk_filters", default_config.chunk_filters
),
"filetype_map": config_dict.get(
"filetype_map", default_config.filetype_map
),
"encoding": config_dict.get("encoding", default_config.encoding),
}
)
Expand Down
115 changes: 115 additions & 0 deletions tests/test_chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,121 @@ def bar():
assert chunks == ['def 测试():\n return "foo"', 'def bar():\n return "bar"']
os.remove(test_file)

def test_treesitter_chunker_javascript():
    """Chunk a small JavaScript file written to a temp path and check the
    two function bodies come back as separate chunks."""
    chunker = TreeSitterChunker(Config(chunk_size=60))

    sample = r"""
function foo() {
    return "foo";
}

function bar() {
    return "bar";
}
"""

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".js") as fout:
        fout.write(sample)
        sample_path = fout.name

    result = [str(chunk) for chunk in chunker.chunk(sample_path)]
    assert result == [
        'function foo() {\n    return "foo";\n}',
        'function bar() {\n    return "bar";\n}',
    ]
    os.remove(sample_path)

def test_treesitter_chunker_javascript_genshi():
    """Chunk a JavaScript file with a non-standard `.kid` extension, forcing
    the javascript parser via the filetype_map config (bypassing lexers)."""
    chunker = TreeSitterChunker(
        Config(chunk_size=60, filetype_map={"javascript": ["^kid$"]})
    )

    sample = r"""
function foo() {
    return `foo with ${genshi}`;
}

function bar() {
    return "bar";
}
"""

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".kid") as fout:
        fout.write(sample)
        sample_path = fout.name

    result = [str(chunk) for chunk in chunker.chunk(sample_path)]
    assert result == [
        'function foo() {\n    return `foo with ${genshi}`;\n}',
        'function bar() {\n    return "bar";\n}',
    ]
    os.remove(sample_path)

def test_treesitter_chunker_parser_from_config_no_parser_found_error():
    """Test TreeSitterChunker filetype_map: should raise an error if no
    parser is found for the configured language."""
    chunker = TreeSitterChunker(
        Config(chunk_size=60, filetype_map={"unknown_parser": ["^kid$"]})
    )

    test_content = r"""
function foo() {
    return `foo with ${genshi}`;
}

function bar() {
    return "bar";
}
"""

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".kid") as tmp_file:
        tmp_file.write(test_content)
        test_file = tmp_file.name

    try:
        # The exception fires inside the generator consumption, so no
        # assertion on the result is possible (the original's
        # `assert chunks == []` after this line was dead code).
        with pytest.raises(LookupError):
            list(chunker.chunk(test_file))
    finally:
        # Guarantee cleanup even if the expected exception is not raised.
        os.remove(test_file)

def test_treesitter_chunker_parser_from_config_regex_error():
    """Test TreeSitterChunker filetype_map: should raise an error if a
    regex pattern in the map is invalid."""
    chunker = TreeSitterChunker(
        Config(chunk_size=60, filetype_map={"javascript": ["\\"]})
    )

    test_content = r"""
function foo() {
    return `foo with ${genshi}`;
}

function bar() {
    return "bar";
}
"""

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".kid") as tmp_file:
        tmp_file.write(test_content)
        test_file = tmp_file.name

    try:
        # NOTE(review): the chunker re-raises `re.error` here; `Exception`
        # is kept broad for compatibility — consider narrowing to re.error.
        # The exception fires while consuming the generator, so the
        # original's trailing `assert chunks == []` was dead code.
        with pytest.raises(Exception):
            list(chunker.chunk(test_file))
    finally:
        # Guarantee cleanup even if the expected exception is not raised.
        os.remove(test_file)

def test_treesitter_chunker_parser_from_config_no_language_match():
    """When no filetype_map regex matches the extension, the chunker should
    fall back to lexer-based language detection and still chunk the file."""
    chunker = TreeSitterChunker(
        Config(chunk_size=60, filetype_map={"php": ["^jsx$"]})
    )

    sample = r"""
function foo() {
    return "foo";
}

function bar() {
    return "bar";
}
"""

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".js") as fout:
        fout.write(sample)
        sample_path = fout.name

    result = [str(chunk) for chunk in chunker.chunk(sample_path)]
    assert result == [
        'function foo() {\n    return "foo";\n}',
        'function bar() {\n    return "bar";\n}',
    ]
    os.remove(sample_path)



def test_treesitter_chunker_filter():
chunker = TreeSitterChunker(
Expand Down