Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,9 @@ The JSON configuration file may hold the following values:
Then the embedding function object will be initialised as
`OllamaEmbeddingFunction(url="http://127.0.0.1:11434/api/embeddings",
model_name="nomic-embed-text")`. Default: `{}`;
- `host` and `port`: string and integer, Chromadb server host and port. VectorCode will start an
- `db_url`: string, the URL that points to the Chromadb server. VectorCode will start an
HTTP server for Chromadb at a randomly picked free port on `localhost` if your
configured `host:port` is not accessible. This allows the use of `AsyncHttpClient`.
Default: `127.0.0.1:8000`;
configured `http://host:port` is not accessible. Default: `http://127.0.0.1:8000`;
- `db_path`: string, Path to local persistent database. This is where the files for
your database will be stored. Default: `~/.local/share/vectorcode/chromadb/`;
- `db_log_path`: string, path to the _directory_ where the built-in chromadb
Expand Down
23 changes: 17 additions & 6 deletions src/vectorcode/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ class Config:
files: list[PathLike] = field(default_factory=list)
project_root: Optional[PathLike] = None
query: Optional[list[str]] = None
host: str = "127.0.0.1"
port: int = 8000
db_url: str = "http://127.0.0.1:8000"
embedding_function: str = "SentenceTransformerEmbeddingFunction" # This should fallback to whatever the default is.
embedding_params: dict[str, Any] = field(default_factory=(lambda: {}))
n_result: int = 1
Expand Down Expand Up @@ -105,8 +104,21 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
"""
default_config = Config()
db_path = config_dict.get("db_path")
host = config_dict.get("host") or "localhost"
port = config_dict.get("port") or 8000
db_url = config_dict.get("db_url")
if db_url is None:
host = config_dict.get("host")
port = config_dict.get("port")
if host is not None or port is not None:
# TODO: deprecate `host` and `port` in 0.7.0
host = host or "127.0.0.1"
port = port or 8000
db_url = f"http://{host}:{port}"
logger.warning(
f'"host" and "port" are deprecated and will be removed in 0.7.0. Use "db_url" (eg. {db_url}).'
)
else:
db_url = "http://127.0.0.1:8000"

if db_path is None:
db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/")
elif not os.path.isdir(db_path):
Expand All @@ -121,8 +133,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
"embedding_params": config_dict.get(
"embedding_params", default_config.embedding_params
),
"host": host,
"port": port,
"db_url": db_url,
"db_path": db_path,
"db_log_path": os.path.expanduser(
config_dict.get("db_log_path", default_config.db_log_path)
Expand Down
47 changes: 25 additions & 22 deletions src/vectorcode/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import socket
import subprocess
import sys
from typing import AsyncGenerator
from typing import Any, AsyncGenerator
from urllib.parse import urlparse

import chromadb
import httpx
from chromadb.api import AsyncClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.config import Settings
from chromadb.config import APIVersion, Settings
from chromadb.utils import embedding_functions

from vectorcode.cli_utils import Config, expand_path
Expand Down Expand Up @@ -40,26 +41,26 @@ async def get_collections(
yield collection


async def try_server(base_url: str):
    """Check whether a ChromaDB server answers a heartbeat at ``base_url``.

    Both the ``v1`` and ``v2`` heartbeat endpoints are probed in turn.

    Args:
        base_url: Root URL of the server, e.g. ``http://127.0.0.1:8000``.
            Trailing slashes are ignored so that ``http://host:8000/`` does
            not produce a ``//api/...`` request path.

    Returns:
        ``True`` as soon as one heartbeat endpoint replies with HTTP 200,
        ``False`` if neither does or the connection cannot be established.
    """
    root_url = base_url.rstrip("/")
    for ver in ("v1", "v2"):  # v1 for legacy, v2 for latest chromadb.
        heartbeat_url = f"{root_url}/api/{ver}/heartbeat"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(url=heartbeat_url)
                logger.debug(f"Heartbeat {heartbeat_url} returned {response=}")
                if response.status_code == 200:
                    return True
        except (httpx.ConnectError, httpx.ConnectTimeout):
            # An unreachable server is an expected outcome of a probe;
            # fall through and report False.
            pass
    return False


async def wait_for_server(host, port, timeout=10):
async def wait_for_server(url: str, timeout=10):
# Poll the server until it's ready or timeout is reached

start_time = asyncio.get_event_loop().time()
while True:
if await try_server(host, port):
if await try_server(url):
return

if asyncio.get_event_loop().time() - start_time > timeout:
Expand All @@ -82,10 +83,8 @@ async def start_server(configs: Config):
env = os.environ.copy()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) # OS selects a free ephemeral port
configs.port = int(s.getsockname()[1])
logger.warning(
f"Starting bundled ChromaDB server at {configs.host}:{configs.port}."
)
port = int(s.getsockname()[1])
logger.warning(f"Starting bundled ChromaDB server at http://127.0.0.1:{port}.")
env.update({"ANONYMIZED_TELEMETRY": "False"})
process = await asyncio.create_subprocess_exec(
sys.executable,
Expand All @@ -95,7 +94,7 @@ async def start_server(configs: Config):
"--host",
"localhost",
"--port",
str(configs.port),
str(port),
"--path",
db_path,
"--log-path",
Expand All @@ -105,28 +104,32 @@ async def start_server(configs: Config):
env=env,
)

await wait_for_server(configs.host, configs.port)
await wait_for_server(f"http://127.0.0.1:{port}")
return process


__CLIENT_CACHE: dict[tuple[str, int], AsyncClientAPI] = {}
__CLIENT_CACHE: dict[str, AsyncClientAPI] = {}


async def get_client(configs: Config) -> AsyncClientAPI:
    """Return the async ChromaDB HTTP client for ``configs.db_url``.

    Clients are kept in the module-level ``__CLIENT_CACHE`` keyed by the URL
    string, so repeated calls with the same ``db_url`` reuse one client.

    Args:
        configs: Configuration providing ``db_url`` and, optionally, a
            ``db_settings`` dict of extra chromadb settings.

    Returns:
        The cached or newly created client.
    """
    cache_key = configs.db_url
    cached_client = __CLIENT_CACHE.get(cache_key)
    if cached_client is not None:
        return cached_client

    settings: dict[str, Any] = {"anonymized_telemetry": False}
    if isinstance(configs.db_settings, dict):
        # Only forward keys that chromadb's Settings model declares.
        settings.update(
            (key, value)
            for key, value in configs.db_settings.items()
            if key in Settings.__fields__
        )

    # Connection details come from the URL and override anything with the
    # same key supplied through db_settings.
    url_parts = urlparse(configs.db_url)
    settings.update(
        {
            "chroma_server_host": url_parts.hostname or "127.0.0.1",
            "chroma_server_http_port": url_parts.port or 8000,
            "chroma_server_ssl_enabled": url_parts.scheme == "https",
            "chroma_server_api_default_path": url_parts.path or APIVersion.V2,
        }
    )
    settings_obj = Settings(**settings)

    new_client = await chromadb.AsyncHttpClient(
        settings=settings_obj,
        host=str(settings_obj.chroma_server_host),
        port=int(settings_obj.chroma_server_http_port or 8000),
    )
    __CLIENT_CACHE[cache_key] = new_client
    return new_client

Expand Down
2 changes: 1 addition & 1 deletion src/vectorcode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def async_main():
from vectorcode.common import start_server, try_server

server_process = None
if not await try_server(final_configs.host, final_configs.port):
if not await try_server(final_configs.db_url):
server_process = await start_server(final_configs)

if final_configs.pipe:
Expand Down
21 changes: 7 additions & 14 deletions tests/subcommands/test_vectorise.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,7 @@ def test_load_files_from_include_no_files(mock_check_tree_files, mock_isfile, tm
@pytest.mark.asyncio
async def test_vectorise(capsys):
configs = Config(
host="test_host",
port=1234,
db_url="http://test_host:1234",
db_path="test_db",
embedding_function="SentenceTransformerEmbeddingFunction",
embedding_params={},
Expand Down Expand Up @@ -330,8 +329,7 @@ async def test_vectorise(capsys):
@pytest.mark.asyncio
async def test_vectorise_cancelled():
configs = Config(
host="test_host",
port=1234,
db_url="http://test_host:1234",
db_path="test_db",
embedding_function="SentenceTransformerEmbeddingFunction",
embedding_params={},
Expand Down Expand Up @@ -373,8 +371,7 @@ async def mock_chunked_add(*args, **kwargs):
@pytest.mark.asyncio
async def test_vectorise_orphaned_files():
configs = Config(
host="test_host",
port=1234,
db_url="http://test_host:1234",
db_path="test_db",
embedding_function="SentenceTransformerEmbeddingFunction",
embedding_params={},
Expand Down Expand Up @@ -443,8 +440,7 @@ def is_file_side_effect(path):
@pytest.mark.asyncio
async def test_vectorise_collection_index_error():
configs = Config(
host="test_host",
port=1234,
db_url="http://test_host:1234",
db_path="test_db",
embedding_function="SentenceTransformerEmbeddingFunction",
embedding_params={},
Expand All @@ -470,8 +466,7 @@ async def test_vectorise_collection_index_error():
@pytest.mark.asyncio
async def test_vectorise_verify_ef_false():
configs = Config(
host="test_host",
port=1234,
db_url="http://test_host:1234",
db_path="test_db",
embedding_function="SentenceTransformerEmbeddingFunction",
embedding_params={},
Expand Down Expand Up @@ -500,8 +495,7 @@ async def test_vectorise_verify_ef_false():
@pytest.mark.asyncio
async def test_vectorise_gitignore():
configs = Config(
host="test_host",
port=1234,
db_url="http://test_host:1234",
db_path="test_db",
embedding_function="SentenceTransformerEmbeddingFunction",
embedding_params={},
Expand Down Expand Up @@ -548,8 +542,7 @@ async def test_vectorise_exclude_file(tmpdir):
exclude_file.write("excluded_file.py\n")

configs = Config(
host="test_host",
port=1234,
db_url="http://test_host:1234",
db_path="test_db",
embedding_function="SentenceTransformerEmbeddingFunction",
embedding_params={},
Expand Down
53 changes: 23 additions & 30 deletions tests/test_cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ async def test_config_import_from():
os.makedirs(db_path, exist_ok=True)
config_dict: Dict[str, Any] = {
"db_path": db_path,
"host": "test_host",
"port": 1234,
"db_url": "http://test_host:1234",
"embedding_function": "TestEmbedding",
"embedding_params": {"param1": "value1"},
"chunk_size": 512,
Expand All @@ -42,8 +41,7 @@ async def test_config_import_from():
config = await Config.import_from(config_dict)
assert config.db_path == db_path
assert config.db_log_path == os.path.expanduser("~/.local/share/vectorcode/")
assert config.host == "test_host"
assert config.port == 1234
assert config.db_url == "http://test_host:1234"
assert config.embedding_function == "TestEmbedding"
assert config.embedding_params == {"param1": "value1"}
assert config.chunk_size == 512
Expand All @@ -54,6 +52,14 @@ async def test_config_import_from():
assert config.db_settings == {"db_setting1": "db_value1"}


@pytest.mark.asyncio
async def test_config_import_from_fallback_host_port():
    """Legacy ``host``/``port`` keys are folded into ``db_url``."""
    # Only `host` supplied: the port falls back to 8000.
    conf = {"host": "test_host"}
    assert (await Config.import_from(conf)).db_url == "http://test_host:8000"

    # Only `port` supplied: the host falls back to 127.0.0.1. The port is
    # kept within 0-65535 so the resulting URL is one that
    # `urllib.parse.urlparse(...).port` accepts.
    conf = {"port": 1234}
    assert (await Config.import_from(conf)).db_url == "http://127.0.0.1:1234"

    # Both supplied: they are combined verbatim.
    conf = {"host": "test_host", "port": 1234}
    assert (await Config.import_from(conf)).db_url == "http://test_host:1234"


@pytest.mark.asyncio
async def test_config_import_from_invalid_path():
config_dict: Dict[str, Any] = {"db_path": "/path/does/not/exist"}
Expand All @@ -75,22 +81,20 @@ async def test_config_import_from_db_path_is_file():

@pytest.mark.asyncio
async def test_config_merge_from():
    """Merging takes ``db_url`` and ``query`` from the second config and keeps ``n_result`` from the first."""
    base = Config(db_url="http://host1:8001", n_result=5)
    override = Config(db_url="http://host2:8002", query=["test"])

    merged = await base.merge_from(override)

    assert merged.db_url == "http://host2:8002"
    assert merged.n_result == 5
    assert merged.query == ["test"]


@pytest.mark.asyncio
async def test_config_merge_from_new_fields():
config1 = Config(host="host1", port=8001)
config1 = Config(db_url="http://host1:8001")
config2 = Config(query=["test"], n_result=10, recursive=True)
merged_config = await config1.merge_from(config2)
assert merged_config.host == "host1"
assert merged_config.port == 8001
assert merged_config.db_url == "http://host1:8001"
assert merged_config.query == ["test"]
assert merged_config.n_result == 10
assert merged_config.recursive
Expand All @@ -104,8 +108,7 @@ async def test_config_import_from_missing_keys():
# Assert that default values are used
assert config.embedding_function == "SentenceTransformerEmbeddingFunction"
assert config.embedding_params == {}
assert config.host == "localhost"
assert config.port == 8000
assert config.db_url == "http://127.0.0.1:8000"
assert config.db_path == os.path.expanduser("~/.local/share/vectorcode/chromadb/")
assert config.chunk_size == 2500
assert config.overlap_ratio == 0.2
Expand Down Expand Up @@ -318,7 +321,7 @@ def test_find_project_root():
async def test_get_project_config_no_local_config():
    """A project directory without a local config yields the default chunk size."""
    with tempfile.TemporaryDirectory() as empty_project:
        loaded = await get_project_config(empty_project)
        expected_chunk_size = Config().chunk_size
        assert loaded.chunk_size == expected_chunk_size, "Should load default value."


@pytest.mark.asyncio
Expand Down Expand Up @@ -394,36 +397,28 @@ async def test_parse_cli_args_vectorise_no_files():

@pytest.mark.asyncio
async def test_get_project_config_local_config(tmp_path):
    """``db_url`` written to ``.vectorcode/config.json`` is returned by ``get_project_config``."""
    project_root = tmp_path / "project"
    config_dir = project_root / ".vectorcode"
    config_dir.mkdir(parents=True)

    # Project-local configuration carrying a custom database URL.
    (config_dir / "config.json").write_text('{"db_url": "http://test_host:9999" }')

    loaded = await get_project_config(project_root)
    assert loaded.db_url == "http://test_host:9999"


@pytest.mark.asyncio
async def test_get_project_config_local_config_json5(tmp_path):
    """``db_url`` written to ``.vectorcode/config.json5`` is returned by ``get_project_config``."""
    project_root = tmp_path / "project"
    config_dir = project_root / ".vectorcode"
    config_dir.mkdir(parents=True)

    # Same payload as the plain JSON case, stored under the .json5 file name.
    (config_dir / "config.json5").write_text('{"db_url": "http://test_host:9999" }')

    loaded = await get_project_config(project_root)
    assert loaded.db_url == "http://test_host:9999"


def test_find_project_root_file_input(tmp_path):
Expand Down Expand Up @@ -512,11 +507,9 @@ async def test_config_import_from_hnsw():

@pytest.mark.asyncio
async def test_hnsw_config_merge():
    """After a merge, ``hnsw`` contains the keys from both configs."""
    first = Config(hnsw={"space": "ip"})
    second = Config(hnsw={"ef_construction": 200})

    merged = await first.merge_from(second)

    merged_hnsw = merged.hnsw
    assert merged_hnsw["space"] == "ip"
    assert merged_hnsw["ef_construction"] == 200

Expand Down
Loading