Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions src/vectorcode/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import asyncio
import contextlib
import hashlib
import json
import logging
import os
import socket
import subprocess
import sys
import traceback
from asyncio.subprocess import Process
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Optional
Expand Down Expand Up @@ -48,19 +46,16 @@ async def get_collections(


async def try_server(base_url: str):
openapi_url = f"{base_url}/openapi.json"
try:
async with httpx.AsyncClient() as client:
response = await client.get(url=openapi_url)
logger.debug(f"Fetching openapi.json from {openapi_url}: {response=}")
if response.status_code != 200:
return False
openapi_json = json.loads(response.content.decode())
if openapi_json:
return openapi_json.get("info", {}).get("title", "").lower() == "chroma"
except Exception as e:
logger.info(f"Failed to connect to chromadb at {base_url}")
logger.debug(traceback.format_exception(e))
for ver in ("v1", "v2"): # v1 for legacy, v2 for latest chromadb.
heartbeat_url = f"{base_url}/api/{ver}/heartbeat"
try:
async with httpx.AsyncClient() as client:
response = await client.get(url=heartbeat_url)
logger.debug(f"Heartbeat {heartbeat_url} returned {response=}")
if response.status_code == 200:
return True
except (httpx.ConnectError, httpx.ConnectTimeout):
pass
return False


Expand Down
8 changes: 7 additions & 1 deletion src/vectorcode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys
import traceback

import httpx

from vectorcode import __version__
from vectorcode.cli_utils import (
CliAction,
Expand Down Expand Up @@ -100,8 +102,12 @@ async def async_main():
from vectorcode.subcommands import files

return_val = await files(final_configs)
except Exception:
except Exception as e:
return_val = 1
if isinstance(e, httpx.RemoteProtocolError): # pragma: nocover
e.add_note(
f"Please verify that {final_configs.db_url} is a working chromadb server."
)
logger.error(traceback.format_exc())
finally:
await ClientManager().kill_servers()
Expand Down
70 changes: 53 additions & 17 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,59 @@ def test_get_embedding_function_init_exception():
)


@pytest.mark.asyncio
async def test_try_server_versions():
# Test successful v1 response
with patch("httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_response
)
assert await try_server("http://localhost:8300") is True
mock_client.return_value.__aenter__.return_value.get.assert_called_once_with(
url="http://localhost:8300/api/v1/heartbeat"
)

# Test fallback to v2 when v1 fails
with patch("httpx.AsyncClient") as mock_client:
mock_response_v1 = MagicMock()
mock_response_v1.status_code = 404
mock_response_v2 = MagicMock()
mock_response_v2.status_code = 200
mock_client.return_value.__aenter__.return_value.get.side_effect = [
mock_response_v1,
mock_response_v2,
]
assert await try_server("http://localhost:8300") is True
assert mock_client.return_value.__aenter__.return_value.get.call_count == 2

# Test both versions fail
with patch("httpx.AsyncClient") as mock_client:
mock_response_v1 = MagicMock()
mock_response_v1.status_code = 404
mock_response_v2 = MagicMock()
mock_response_v2.status_code = 500
mock_client.return_value.__aenter__.return_value.get.side_effect = [
mock_response_v1,
mock_response_v2,
]
assert await try_server("http://localhost:8300") is False

# Test connection error cases
with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.get.side_effect = (
httpx.ConnectError("Cannot connect")
)
assert await try_server("http://localhost:8300") is False

with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.get.side_effect = (
httpx.ConnectTimeout("Connection timeout")
)
assert await try_server("http://localhost:8300") is False


def test_verify_ef():
# Mocking AsyncCollection and Config
mock_collection = MagicMock()
Expand Down Expand Up @@ -137,18 +190,10 @@ async def test_try_server_mocked(mock_socket):
with patch("httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b'{"info":{"title": "Chroma"}}'
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_response
)
assert await try_server("http://localhost:8000") is True
with patch("httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 404
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_response
)
assert await try_server("http://localhost:8000") is False

# Mocking httpx.AsyncClient to raise a ConnectError
with patch("httpx.AsyncClient") as mock_client:
Expand All @@ -157,15 +202,6 @@ async def test_try_server_mocked(mock_socket):
)
assert await try_server("http://localhost:8000") is False

with patch("httpx.AsyncClient") as mock_client:
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.content = b'{"info":{"title": "Dummy"}}'
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_response
)
assert await try_server("http://localhost:8000") is False

# Mocking httpx.AsyncClient to raise a ConnectTimeout
with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.get.side_effect = (
Expand Down
Loading