Merged
3 changes: 2 additions & 1 deletion doc/VectorCode-cli.txt
@@ -781,7 +781,8 @@ features:

- `ls`: list local collections, similar to the `ls` subcommand in the CLI;
- `query`: query from a given collection, similar to the `query` subcommand in
the CLI.
the CLI;
- `vectorise`: vectorise files into a given project.

To try it out, install the `vectorcode[mcp]` dependency group and the MCP
server is available in the shell as `vectorcode-mcp-server`, and make sure
3 changes: 2 additions & 1 deletion docs/cli.md
@@ -706,7 +706,8 @@ features:

- `ls`: list local collections, similar to the `ls` subcommand in the CLI;
- `query`: query from a given collection, similar to the `query` subcommand in
the CLI.
the CLI;
- `vectorise`: vectorise files into a given project.

To try it out, install the `vectorcode[mcp]` dependency group and the MCP server
is available in the shell as `vectorcode-mcp-server`, and make sure you're using
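The documentation above now lists three MCP tools: `ls`, `query`, and the new `vectorise`. As a rough illustration of how they fit together (not part of this diff), the sketch below drives `vectorcode-mcp-server` over stdio with the upstream `mcp` Python SDK; the tool names and argument fields come from the handlers added in `src/vectorcode/mcp_main.py`, while the client-side wiring and the example paths are assumptions.

```python
# Hedged sketch: calling the VectorCode MCP tools from the `mcp` Python SDK.
# Only the tool names and their argument fields are taken from this PR;
# the client setup and the paths below are illustrative assumptions.
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

server_params = StdioServerParameters(command="vectorcode-mcp-server", args=[])


async def main() -> None:
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()

            # `ls`: discover valid project_root values before anything else.
            projects = await session.call_tool("ls", arguments={})

            # `vectorise`: index files into a given project (hypothetical path).
            await session.call_tool(
                "vectorise",
                arguments={"paths": ["src/main.py"], "project_root": "/path/to/project"},
            )

            # `query`: vector similarity search over the same project.
            hits = await session.call_tool(
                "query",
                arguments={
                    "n_query": 5,
                    "query_messages": ["database connection setup"],
                    "project_root": "/path/to/project",
                },
            )
            print(projects, hits)


asyncio.run(main())
```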
92 changes: 85 additions & 7 deletions src/vectorcode/mcp_main.py
@@ -12,6 +13,13 @@
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.errors import InvalidCollectionException

from vectorcode.subcommands.vectorise import (
VectoriseStats,
chunked_add,
exclude_paths_by_spec,
find_exclude_specs,
)

try: # pragma: nocover
from mcp import ErrorData, McpError
from mcp.server.fastmcp import FastMCP
@@ -26,6 +33,7 @@
Config,
cleanup_path,
config_logging,
expand_globs,
find_project_config_dir,
get_project_config,
load_config_file,
@@ -89,6 +97,69 @@ async def list_collections() -> list[str]:
return names


async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]:
logger.info(
f"vectorise tool called with the following args: {paths=}, {project_root=}"
)
project_root = os.path.expanduser(project_root)
if not os.path.isdir(project_root):
logger.error(f"Invalid project root: {project_root}")
raise McpError(
ErrorData(code=1, message=f"{project_root} is not a valid path.")
)
config = await get_project_config(project_root)
try:
client = await get_client(config)
collection = await get_collection(client, config, True)
except Exception as e:
logger.error("Failed to access collection at %s", project_root)
raise McpError(
ErrorData(
code=1,
message=f"{e.__class__.__name__}: Failed to create the collection at {project_root}.",
)
)
if collection is None: # pragma: nocover
raise McpError(
ErrorData(
code=1,
message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
)
)

paths = [os.path.expanduser(i) for i in await expand_globs(paths)]
final_config = await config.merge_from(
Config(files=[i for i in paths if os.path.isfile(i)], project_root=project_root)
)
for ignore_spec in find_exclude_specs(final_config):
if os.path.isfile(ignore_spec):
logger.info(f"Loading ignore specs from {ignore_spec}.")
paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec)
stats = VectoriseStats()
collection_lock = asyncio.Lock()
stats_lock = asyncio.Lock()
max_batch_size = await client.get_max_batch_size()
semaphore = asyncio.Semaphore(os.cpu_count() or 1)
tasks = [
asyncio.create_task(
chunked_add(
str(file),
collection,
collection_lock,
stats,
stats_lock,
final_config,
max_batch_size,
semaphore,
)
)
for file in paths
]
for task in asyncio.as_completed(tasks):
await task
return stats.to_dict()


async def query_tool(
n_query: int, query_messages: list[str], project_root: str
) -> list[str]:
@@ -186,18 +257,25 @@ async def mcp_server():
mcp.add_tool(
fn=list_collections,
name="ls",
description="List all projects indexed by VectorCode. Call this before making queries.",
description="\n".join(
prompt_by_categories["ls"] + prompt_by_categories["general"]
),
)

mcp.add_tool(
fn=query_tool,
name="query",
description=f"""
Use VectorCode to perform vector similarity search on repositories and return a list of relevant file paths and contents.
Make sure `project_root` is one of the values from the `ls` tool.
Unless the user requested otherwise, start your retrievals by {mcp_config.n_results} files.
The result contains the relative paths for the files and their corresponding contents.
""",
description="\n".join(
prompt_by_categories["query"] + prompt_by_categories["general"]
),
)

mcp.add_tool(
fn=vectorise_files,
name="vectorise",
description="\n".join(
prompt_by_categories["vectorise"] + prompt_by_categories["general"]
),
)

return mcp
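Outside of MCP, the new handler can also be exercised directly in an async context. A minimal sketch, assuming `project_root` points at a directory VectorCode can already index (the glob and path below are hypothetical):

```python
# Minimal sketch: calling vectorise_files() directly. The glob and project
# root are hypothetical; paths go through expand_globs() and ~ is expanded.
import asyncio

from vectorcode.mcp_main import vectorise_files


async def main() -> None:
    stats = await vectorise_files(
        paths=["src/**/*.py"],
        project_root="~/projects/my_repo",
    )
    # Returns VectoriseStats.to_dict(); the new tests check the "add" count.
    print(stats)


asyncio.run(main())
```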
9 changes: 6 additions & 3 deletions src/vectorcode/subcommands/query/__init__.py
@@ -1,7 +1,7 @@
import json
import logging
import os
from typing import cast
from typing import Any, cast

from chromadb import GetResult, Where
from chromadb.api.models.AsyncCollection import AsyncCollection
@@ -49,12 +49,15 @@ async def get_query_result_files(
try:
if len(configs.query_exclude):
logger.info(f"Excluding {len(configs.query_exclude)} files from the query.")
filter: dict[str, dict] = {"path": {"$nin": configs.query_exclude}}
filter: dict[str, Any] = {"path": {"$nin": configs.query_exclude}}
else:
filter = {}
num_query = configs.n_result
if QueryInclude.chunk in configs.include:
filter["start"] = {"$gte": 0}
if filter:
filter = {"$and": [filter.copy(), {"$gte": 0}]}
else:
filter["start"] = {"$gte": 0}
else:
num_query = await collection.count()
if configs.query_multiplier > 0:
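Roughly, the change above means that when chunk-level results are requested and `query_exclude` is non-empty, the two conditions are combined under Chroma's `$and` operator instead of being merged into one flat dict. A small sketch of the resulting `where` clauses, using hypothetical values:

```python
# Hypothetical values illustrating how the `where` filter is now composed.
query_exclude = ["/excluded/path.py"]

# Exclusion condition built when query_exclude is non-empty.
path_filter = {"path": {"$nin": query_exclude}}

# With QueryInclude.chunk requested, both conditions are joined with "$and";
# Chroma expects multiple conditions to be combined with a logical operator.
where_with_exclude = {"$and": [path_filter, {"start": {"$gte": 0}}]}

# With no exclusions, the single condition stays flat.
where_without_exclude = {"start": {"$gte": 0}}
```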
30 changes: 30 additions & 0 deletions tests/subcommands/query/test_query.py
@@ -205,6 +205,36 @@ async def test_get_query_result_files_with_query_exclude(mock_collection, mock_c
assert kwargs["where"] == {"path": {"$nin": ["/excluded/path.py"]}}


@pytest.mark.asyncio
async def test_get_query_result_chunks_with_query_exclude(mock_collection, mock_config):
# Setup query_exclude
mock_config.query_exclude = ["/excluded/path.py"]
mock_config.include = [QueryInclude.chunk, QueryInclude.path]

with (
patch("vectorcode.subcommands.query.expand_path") as mock_expand_path,
patch("vectorcode.subcommands.query.expand_globs") as mock_expand_globs,
patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker,
patch("os.path.isfile", return_value=True), # Add this line to mock isfile
):
mock_expand_globs.return_value = ["/excluded/path.py"]
mock_expand_path.return_value = "/excluded/path.py"

mock_reranker_instance = MagicMock()
mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
MockReranker.return_value = mock_reranker_instance

# Call the function
await get_query_result_files(mock_collection, mock_config)

# Check that query was called with the right parameters including the where clause
mock_collection.query.assert_called_once()
_, kwargs = mock_collection.query.call_args
assert kwargs["where"] == {
"$and": [{"path": {"$nin": ["/excluded/path.py"]}}, {"$gte": 0}]
}


@pytest.mark.asyncio
async def test_get_query_reranker_initialisation_error(mock_collection, mock_config):
# Configure to use CrossEncoder reranker
128 changes: 125 additions & 3 deletions tests/test_mcp.py
@@ -1,5 +1,7 @@
import os
import tempfile
from argparse import ArgumentParser
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, mock_open, patch

import pytest
from mcp import McpError
@@ -11,6 +13,7 @@
mcp_server,
parse_cli_args,
query_tool,
vectorise_files,
)


@@ -168,6 +171,125 @@ async def test_query_tool_no_collection():
)


@pytest.mark.asyncio
async def test_vectorise_tool_invalid_project_root():
with (
patch("os.path.isdir", return_value=False),
):
with pytest.raises(McpError):
await vectorise_files(paths=["foo.bar"], project_root=".")


@pytest.mark.asyncio
async def test_vectorise_files_success():
with tempfile.TemporaryDirectory() as temp_dir:
file_path = f"{temp_dir}/test_file.py"
with open(file_path, "w") as f:
f.write("def func(): pass")

with (
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
patch("vectorcode.mcp_main.get_client") as mock_get_client,
patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
patch("vectorcode.subcommands.vectorise.chunked_add"),
patch(
"vectorcode.subcommands.vectorise.hash_file", return_value="test_hash"
),
):
mock_config = Config(project_root=temp_dir)
mock_get_project_config.return_value = mock_config
mock_client = AsyncMock()
mock_get_client.return_value = mock_client
mock_collection = AsyncMock()
mock_collection.get.return_value = {"ids": [], "metadatas": []}
mock_get_collection.return_value = mock_collection
mock_client.get_max_batch_size.return_value = 100

result = await vectorise_files(paths=[file_path], project_root=temp_dir)

assert result["add"] == 1
mock_get_project_config.assert_called_once_with(temp_dir)
mock_get_client.assert_called_once_with(mock_config)
mock_get_collection.assert_called_once_with(mock_client, mock_config, True)


@pytest.mark.asyncio
async def test_vectorise_files_collection_access_failure():
with (
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config"),
patch("vectorcode.mcp_main.get_client", side_effect=Exception("Client error")),
patch("vectorcode.mcp_main.get_collection"),
):
with pytest.raises(McpError) as exc_info:
await vectorise_files(paths=["file.py"], project_root="/valid/path")

assert exc_info.value.error.code == 1
assert (
"Failed to create the collection at /valid/path"
in exc_info.value.error.message
)


@pytest.mark.asyncio
async def test_vectorise_files_with_exclude_spec():
with tempfile.TemporaryDirectory() as temp_dir:
file1 = f"{temp_dir}/file1.py"
excluded_file = f"{temp_dir}/excluded.py"
exclude_spec_file = f"{temp_dir}/.vectorcode/vectorcode.exclude"

os.makedirs(f"{temp_dir}/.vectorcode")
with open(file1, "w") as f:
f.write("content1")
with open(excluded_file, "w") as f:
f.write("content_excluded")

# Create mock file handles for specific file contents
mock_exclude_file_handle = mock_open(read_data="excluded.py").return_value

def mock_open_side_effect(filename, *args, **kwargs):
if filename == exclude_spec_file:
return mock_exclude_file_handle
# For other files that might be opened, return a generic mock
return MagicMock()

with (
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
patch("vectorcode.mcp_main.get_client") as mock_get_client,
patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add,
patch(
"vectorcode.subcommands.vectorise.hash_file", return_value="test_hash"
),
# Patch builtins.open with the custom side effect
patch("builtins.open", side_effect=mock_open_side_effect),
# Patch os.path.isfile to control which files "exist"
patch(
"os.path.isfile",
side_effect=lambda x: x in [file1, excluded_file, exclude_spec_file],
),
):
mock_config = Config(project_root=temp_dir)
mock_get_project_config.return_value = mock_config
mock_client = AsyncMock()
mock_get_client.return_value = mock_client
mock_collection = AsyncMock()
mock_collection.get.return_value = {"ids": [], "metadatas": []}
mock_get_collection.return_value = mock_collection
mock_client.get_max_batch_size.return_value = 100

result = await vectorise_files(
paths=[file1, excluded_file], project_root=temp_dir
)

assert result["add"] == 0
assert mock_chunked_add.call_count == 0
call_args = [call[0][0] for call in mock_chunked_add.call_args_list]
assert excluded_file not in call_args


@pytest.mark.asyncio
async def test_mcp_server():
with (
@@ -188,7 +310,7 @@ async def test_mcp_server():

await mcp_server()

assert mock_add_tool.call_count == 2
assert mock_add_tool.call_count == 3


@pytest.mark.asyncio
@@ -223,7 +345,7 @@ async def new_get_collections(clients):

await mcp_server()

assert mock_add_tool.call_count == 2
assert mock_add_tool.call_count == 3
mock_get_collections.assert_called()

