diff --git a/src/vectorcode/chunking.py b/src/vectorcode/chunking.py
index dbcd0a7e..906c68fb 100644
--- a/src/vectorcode/chunking.py
+++ b/src/vectorcode/chunking.py
@@ -141,6 +141,7 @@ def __init__(self, config: Optional[Config] = None):
         if config is None:
             config = Config()
         super().__init__(config)
+        self._fallback_chunker = StringChunker(config)
 
     def __chunk_node(
         self, node: Node, text_bytes: bytes
@@ -153,6 +154,12 @@ def __chunk_node(
         prev_node = None
         current_start = None
 
+        logger.debug("nbr children: %s", len(node.children))
+        # A childless node has nothing to iterate over; fall back to the string chunker.
+        if not node.children and node.text:
+            logger.debug("No children, falling back to string chunker")
+            yield from self._fallback_chunker.chunk(node.text.decode())
+
         for child in node.children:
             child_bytes = text_bytes[child.start_byte : child.end_byte]
             child_text = child_bytes.decode()
@@ -307,7 +314,7 @@ def chunk(self, data: str) -> Generator[Chunk, None, None]:
             logger.debug(
                 "Unable to pick a suitable parser. Fall back to naive chunking"
             )
-            yield from StringChunker(self.config).chunk(content)
+            yield from self._fallback_chunker.chunk(content)
         else:
             pattern_str = self.__build_pattern(language=language)
             content_bytes = content.encode()
diff --git a/tests/test_chunking.py b/tests/test_chunking.py
index 094b7b02..d7d5bda2 100644
--- a/tests/test_chunking.py
+++ b/tests/test_chunking.py
@@ -1,5 +1,6 @@
 import os
 import tempfile
+from unittest.mock import MagicMock
 
 import pytest
 from tree_sitter import Point
@@ -159,6 +160,30 @@ def bar():
     os.remove(test_file)
 
 
+def test_treesitter_chunker_fallback_on_long_node():
+    """TreeSitterChunker should hand a childless node to its fallback chunker."""
+    test_content = r"""
+def foo():
+    return "a very very very very very long string"
+    """
+    config = Config(chunk_size=15)
+    with tempfile.NamedTemporaryFile(
+        mode="w", delete=False, suffix=".py"
+    ) as temp_py_file:
+        temp_py_file.write(test_content)
+        test_file = temp_py_file.name
+
+    # Chunk only after the file is closed, so the buffered write is on disk.
+    try:
+        ts_chunker = TreeSitterChunker(config)
+        ts_chunker._fallback_chunker.chunk = MagicMock()
+        list(ts_chunker.chunk(test_file))
+        ts_chunker._fallback_chunker.chunk.assert_called_once()
+    finally:
+        # delete=False leaves the file behind, so remove it explicitly.
+        os.remove(test_file)
+
+
 def test_treesitter_chunker_python_encoding():
     """Test TreeSitterChunker with a sample file using tempfile."""
     chunker = TreeSitterChunker(Config(chunk_size=30, encoding="gbk"))