diff --git a/doc/VectorCode-cli.txt b/doc/VectorCode-cli.txt
index c56228db..08d58b1d 100644
--- a/doc/VectorCode-cli.txt
+++ b/doc/VectorCode-cli.txt
@@ -781,7 +781,8 @@ features:
 
 - `ls`list local collections, similar to the `ls` subcommand in the CLI;
 - `query`query from a given collection, similar to the `query` subcommand in
-  the CLI.
+  the CLI;
+- `vectorise`vectorise files into a given project.
 
 To try it out, install the `vectorcode[mcp]` dependency group and the MCP
 server is available in the shell as `vectorcode-mcp-server`, and make sure
diff --git a/docs/cli.md b/docs/cli.md
index 49e7e1ba..688c4f8a 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -706,7 +706,8 @@ features:
 
 - `ls`: list local collections, similar to the `ls` subcommand in the CLI;
 - `query`: query from a given collection, similar to the `query` subcommand in
-  the CLI.
+  the CLI;
+- `vectorise`: vectorise files into a given project.
 
 To try it out, install the `vectorcode[mcp]` dependency group and the MCP server
 is available in the shell as `vectorcode-mcp-server`, and make sure you're using
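For context on the documented `vectorise` tool, here is a minimal client-side sketch of how it might be invoked, assuming the standard `mcp` Python SDK stdio client and that `vectorcode-mcp-server` is on PATH; the file paths and project root below are placeholders, not part of this change.

import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def main() -> None:
    # Spawn the VectorCode MCP server as a stdio subprocess.
    params = StdioServerParameters(command="vectorcode-mcp-server")
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Vectorise two placeholder files into a placeholder project root.
            result = await session.call_tool(
                "vectorise",
                {
                    "paths": ["src/main.py", "README.md"],
                    "project_root": "~/projects/demo",
                },
            )
            print(result)


asyncio.run(main())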
diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py
index da97a69a..e9a9812c 100644
--- a/src/vectorcode/mcp_main.py
+++ b/src/vectorcode/mcp_main.py
@@ -12,6 +12,13 @@
 from chromadb.api.models.AsyncCollection import AsyncCollection
 from chromadb.errors import InvalidCollectionException
 
+from vectorcode.subcommands.vectorise import (
+    VectoriseStats,
+    chunked_add,
+    exclude_paths_by_spec,
+    find_exclude_specs,
+)
+
 try:  # pragma: nocover
     from mcp import ErrorData, McpError
     from mcp.server.fastmcp import FastMCP
@@ -26,6 +33,7 @@
     Config,
     cleanup_path,
     config_logging,
+    expand_globs,
     find_project_config_dir,
     get_project_config,
     load_config_file,
@@ -89,6 +97,69 @@ async def list_collections() -> list[str]:
     return names
 
+
+async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]:
+    logger.info(
+        f"vectorise tool called with the following args: {paths=}, {project_root=}"
+    )
+    project_root = os.path.expanduser(project_root)
+    if not os.path.isdir(project_root):
+        logger.error(f"Invalid project root: {project_root}")
+        raise McpError(
+            ErrorData(code=1, message=f"{project_root} is not a valid path.")
+        )
+    config = await get_project_config(project_root)
+    try:
+        client = await get_client(config)
+        collection = await get_collection(client, config, True)
+    except Exception as e:
+        logger.error("Failed to access collection at %s", project_root)
+        raise McpError(
+            ErrorData(
+                code=1,
+                message=f"{e.__class__.__name__}: Failed to create the collection at {project_root}.",
+            )
+        )
+    if collection is None:  # pragma: nocover
+        raise McpError(
+            ErrorData(
+                code=1,
+                message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
+            )
+        )
+
+    paths = [os.path.expanduser(i) for i in await expand_globs(paths)]
+    final_config = await config.merge_from(
+        Config(files=[i for i in paths if os.path.isfile(i)], project_root=project_root)
+    )
+    for ignore_spec in find_exclude_specs(final_config):
+        if os.path.isfile(ignore_spec):
+            logger.info(f"Loading ignore specs from {ignore_spec}.")
+            paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec)
+    stats = VectoriseStats()
+    collection_lock = asyncio.Lock()
+    stats_lock = asyncio.Lock()
+    max_batch_size = await client.get_max_batch_size()
+    semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+    tasks = [
+        asyncio.create_task(
+            chunked_add(
+                str(file),
+                collection,
+                collection_lock,
+                stats,
+                stats_lock,
+                final_config,
+                max_batch_size,
+                semaphore,
+            )
+        )
+        for file in paths
+    ]
+    for i, task in enumerate(asyncio.as_completed(tasks), start=1):
+        await task
+    return stats.to_dict()
+
 
 async def query_tool(
     n_query: int, query_messages: list[str], project_root: str
 ) -> list[str]:
@@ -186,18 +257,25 @@ async def mcp_server():
     mcp.add_tool(
         fn=list_collections,
         name="ls",
-        description="List all projects indexed by VectorCode. Call this before making queries.",
+        description="\n".join(
+            prompt_by_categories["ls"] + prompt_by_categories["general"]
+        ),
     )
     mcp.add_tool(
         fn=query_tool,
         name="query",
-        description=f"""
-Use VectorCode to perform vector similarity search on repositories and return a list of relevant file paths and contents.
-Make sure `project_root` is one of the values from the `ls` tool.
-Unless the user requested otherwise, start your retrievals by {mcp_config.n_results} files.
-The result contains the relative paths for the files and their corresponding contents.
-""",
+        description="\n".join(
+            prompt_by_categories["query"] + prompt_by_categories["general"]
+        ),
+    )
+
+    mcp.add_tool(
+        fn=vectorise_files,
+        name="vectorise",
+        description="\n".join(
+            prompt_by_categories["vectorise"] + prompt_by_categories["general"]
+        ),
     )
     return mcp
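The `vectorise_files` tool above fans out one `chunked_add` task per file, bounded by a semaphore sized to the CPU count, with separate locks guarding collection writes and the shared statistics. Below is a rough, self-contained sketch of that concurrency pattern; the chunking and the collection write are stubbed out, so this is not VectorCode's real `chunked_add`.

import asyncio
import os


async def add_one(
    path: str,
    db_lock: asyncio.Lock,
    stats: dict[str, int],
    stats_lock: asyncio.Lock,
    sem: asyncio.Semaphore,
) -> None:
    async with sem:  # cap how many files are processed concurrently
        chunks = [f"{path}:0"]  # stand-in for chunking/embedding the file
        async with db_lock:  # serialise writes to the shared collection
            await asyncio.sleep(0)  # stand-in for collection.add(...)
        async with stats_lock:  # keep the shared counters consistent
            stats["add"] = stats.get("add", 0) + len(chunks)


async def vectorise(paths: list[str]) -> dict[str, int]:
    stats: dict[str, int] = {}
    db_lock, stats_lock = asyncio.Lock(), asyncio.Lock()
    sem = asyncio.Semaphore(os.cpu_count() or 1)
    tasks = [
        asyncio.create_task(add_one(p, db_lock, stats, stats_lock, sem)) for p in paths
    ]
    for task in asyncio.as_completed(tasks):
        await task
    return stats


if __name__ == "__main__":
    print(asyncio.run(vectorise(["a.py", "b.py", "c.py"])))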
-""", + description="\n".join( + prompt_by_categories["query"] + prompt_by_categories["general"] + ), + ) + + mcp.add_tool( + fn=vectorise_files, + name="vectorise", + description="\n".join( + prompt_by_categories["vectorise"] + prompt_by_categories["general"] + ), ) return mcp diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py index 0341772b..4f4b507a 100644 --- a/src/vectorcode/subcommands/query/__init__.py +++ b/src/vectorcode/subcommands/query/__init__.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import cast +from typing import Any, cast from chromadb import GetResult, Where from chromadb.api.models.AsyncCollection import AsyncCollection @@ -49,12 +49,15 @@ async def get_query_result_files( try: if len(configs.query_exclude): logger.info(f"Excluding {len(configs.query_exclude)} files from the query.") - filter: dict[str, dict] = {"path": {"$nin": configs.query_exclude}} + filter: dict[str, Any] = {"path": {"$nin": configs.query_exclude}} else: filter = {} num_query = configs.n_result if QueryInclude.chunk in configs.include: - filter["start"] = {"$gte": 0} + if filter: + filter = {"$and": [filter.copy(), {"$gte": 0}]} + else: + filter["start"] = {"$gte": 0} else: num_query = await collection.count() if configs.query_multiplier > 0: diff --git a/tests/subcommands/query/test_query.py b/tests/subcommands/query/test_query.py index 64463fca..a6b17689 100644 --- a/tests/subcommands/query/test_query.py +++ b/tests/subcommands/query/test_query.py @@ -205,6 +205,36 @@ async def test_get_query_result_files_with_query_exclude(mock_collection, mock_c assert kwargs["where"] == {"path": {"$nin": ["/excluded/path.py"]}} +@pytest.mark.asyncio +async def test_get_query_result_chunks_with_query_exclude(mock_collection, mock_config): + # Setup query_exclude + mock_config.query_exclude = ["/excluded/path.py"] + mock_config.include = [QueryInclude.chunk, QueryInclude.path] + + with ( + patch("vectorcode.subcommands.query.expand_path") as mock_expand_path, + patch("vectorcode.subcommands.query.expand_globs") as mock_expand_globs, + patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker, + patch("os.path.isfile", return_value=True), # Add this line to mock isfile + ): + mock_expand_globs.return_value = ["/excluded/path.py"] + mock_expand_path.return_value = "/excluded/path.py" + + mock_reranker_instance = MagicMock() + mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"]) + MockReranker.return_value = mock_reranker_instance + + # Call the function + await get_query_result_files(mock_collection, mock_config) + + # Check that query was called with the right parameters including the where clause + mock_collection.query.assert_called_once() + _, kwargs = mock_collection.query.call_args + assert kwargs["where"] == { + "$and": [{"path": {"$nin": ["/excluded/path.py"]}}, {"$gte": 0}] + } + + @pytest.mark.asyncio async def test_get_query_reranker_initialisation_error(mock_collection, mock_config): # Configure to use CrossEncoder reranker diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 004e92dc..a9be2f11 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,5 +1,7 @@ +import os +import tempfile from argparse import ArgumentParser -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, mock_open, patch import pytest from mcp import McpError @@ -11,6 +13,7 @@ mcp_server, parse_cli_args, query_tool, + vectorise_files, ) @@ -168,6 
diff --git a/tests/subcommands/query/test_query.py b/tests/subcommands/query/test_query.py
index 64463fca..a6b17689 100644
--- a/tests/subcommands/query/test_query.py
+++ b/tests/subcommands/query/test_query.py
@@ -205,6 +205,36 @@ async def test_get_query_result_files_with_query_exclude(mock_collection, mock_c
     assert kwargs["where"] == {"path": {"$nin": ["/excluded/path.py"]}}
 
+
+@pytest.mark.asyncio
+async def test_get_query_result_chunks_with_query_exclude(mock_collection, mock_config):
+    # Setup query_exclude
+    mock_config.query_exclude = ["/excluded/path.py"]
+    mock_config.include = [QueryInclude.chunk, QueryInclude.path]
+
+    with (
+        patch("vectorcode.subcommands.query.expand_path") as mock_expand_path,
+        patch("vectorcode.subcommands.query.expand_globs") as mock_expand_globs,
+        patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker,
+        patch("os.path.isfile", return_value=True),  # Add this line to mock isfile
+    ):
+        mock_expand_globs.return_value = ["/excluded/path.py"]
+        mock_expand_path.return_value = "/excluded/path.py"
+
+        mock_reranker_instance = MagicMock()
+        mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
+        MockReranker.return_value = mock_reranker_instance
+
+        # Call the function
+        await get_query_result_files(mock_collection, mock_config)
+
+        # Check that query was called with the right parameters including the where clause
+        mock_collection.query.assert_called_once()
+        _, kwargs = mock_collection.query.call_args
+        assert kwargs["where"] == {
+            "$and": [{"path": {"$nin": ["/excluded/path.py"]}}, {"start": {"$gte": 0}}]
+        }
+
 
 @pytest.mark.asyncio
 async def test_get_query_reranker_initialisation_error(mock_collection, mock_config):
     # Configure to use CrossEncoder reranker
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index 004e92dc..a9be2f11 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -1,5 +1,7 @@
+import os
+import tempfile
 from argparse import ArgumentParser
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, mock_open, patch
 
 import pytest
 from mcp import McpError
@@ -11,6 +13,7 @@
     mcp_server,
     parse_cli_args,
     query_tool,
+    vectorise_files,
 )
 
 
@@ -168,6 +171,125 @@ async def test_query_tool_no_collection():
             )
 
+
+@pytest.mark.asyncio
+async def test_vectorise_tool_invalid_project_root():
+    with (
+        patch("os.path.isdir", return_value=False),
+    ):
+        with pytest.raises(McpError):
+            await vectorise_files(paths=["foo.bar"], project_root=".")
+
+
+@pytest.mark.asyncio
+async def test_vectorise_files_success():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        file_path = f"{temp_dir}/test_file.py"
+        with open(file_path, "w") as f:
+            f.write("def func(): pass")
+
+        with (
+            patch("os.path.isdir", return_value=True),
+            patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
+            patch("vectorcode.mcp_main.get_client") as mock_get_client,
+            patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
+            patch("vectorcode.subcommands.vectorise.chunked_add"),
+            patch(
+                "vectorcode.subcommands.vectorise.hash_file", return_value="test_hash"
+            ),
+        ):
+            mock_config = Config(project_root=temp_dir)
+            mock_get_project_config.return_value = mock_config
+            mock_client = AsyncMock()
+            mock_get_client.return_value = mock_client
+            mock_collection = AsyncMock()
+            mock_collection.get.return_value = {"ids": [], "metadatas": []}
+            mock_get_collection.return_value = mock_collection
+            mock_client.get_max_batch_size.return_value = 100
+
+            result = await vectorise_files(paths=[file_path], project_root=temp_dir)
+
+            assert result["add"] == 1
+            mock_get_project_config.assert_called_once_with(temp_dir)
+            mock_get_client.assert_called_once_with(mock_config)
+            mock_get_collection.assert_called_once_with(mock_client, mock_config, True)
+
+
+@pytest.mark.asyncio
+async def test_vectorise_files_collection_access_failure():
+    with (
+        patch("os.path.isdir", return_value=True),
+        patch("vectorcode.mcp_main.get_project_config"),
+        patch("vectorcode.mcp_main.get_client", side_effect=Exception("Client error")),
+        patch("vectorcode.mcp_main.get_collection"),
+    ):
+        with pytest.raises(McpError) as exc_info:
+            await vectorise_files(paths=["file.py"], project_root="/valid/path")
+
+        assert exc_info.value.error.code == 1
+        assert (
+            "Failed to create the collection at /valid/path"
+            in exc_info.value.error.message
+        )
+
+
+@pytest.mark.asyncio
+async def test_vectorise_files_with_exclude_spec():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        file1 = f"{temp_dir}/file1.py"
+        excluded_file = f"{temp_dir}/excluded.py"
+        exclude_spec_file = f"{temp_dir}/.vectorcode/vectorcode.exclude"
+
+        os.makedirs(f"{temp_dir}/.vectorcode")
+        with open(file1, "w") as f:
+            f.write("content1")
+        with open(excluded_file, "w") as f:
+            f.write("content_excluded")
+
+        # Create mock file handles for specific file contents
+        mock_exclude_file_handle = mock_open(read_data="excluded.py").return_value
+
+        def mock_open_side_effect(filename, *args, **kwargs):
+            if filename == exclude_spec_file:
+                return mock_exclude_file_handle
+            # For other files that might be opened, return a generic mock
+            return MagicMock()
+
+        with (
+            patch("os.path.isdir", return_value=True),
+            patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
+            patch("vectorcode.mcp_main.get_client") as mock_get_client,
+            patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
+            patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add,
+            patch(
+                "vectorcode.subcommands.vectorise.hash_file", return_value="test_hash"
+            ),
+            # Patch builtins.open with the custom side effect
+            patch("builtins.open", side_effect=mock_open_side_effect),
+            # Patch os.path.isfile to control which files "exist"
files "exist" + patch( + "os.path.isfile", + side_effect=lambda x: x in [file1, excluded_file, exclude_spec_file], + ), + ): + mock_config = Config(project_root=temp_dir) + mock_get_project_config.return_value = mock_config + mock_client = AsyncMock() + mock_get_client.return_value = mock_client + mock_collection = AsyncMock() + mock_collection.get.return_value = {"ids": [], "metadatas": []} + mock_get_collection.return_value = mock_collection + mock_client.get_max_batch_size.return_value = 100 + + result = await vectorise_files( + paths=[file1, excluded_file], project_root=temp_dir + ) + + assert result["add"] == 0 + assert mock_chunked_add.call_count == 0 + call_args = [call[0][0] for call in mock_chunked_add.call_args_list] + assert excluded_file not in call_args + + @pytest.mark.asyncio async def test_mcp_server(): with ( @@ -188,7 +310,7 @@ async def test_mcp_server(): await mcp_server() - assert mock_add_tool.call_count == 2 + assert mock_add_tool.call_count == 3 @pytest.mark.asyncio @@ -223,7 +345,7 @@ async def new_get_collections(clients): await mcp_server() - assert mock_add_tool.call_count == 2 + assert mock_add_tool.call_count == 3 mock_get_collections.assert_called()