From b2e8c263362b710c0ee9606bcc0c3a32ce98c23a Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Tue, 8 Jul 2025 14:29:59 +0800 Subject: [PATCH] fix(cli): Verify chromadb connection by checking openapi title --- src/vectorcode/common.py | 25 ++++++++------ tests/test_common.py | 70 ++++++++++------------------------------ 2 files changed, 32 insertions(+), 63 deletions(-) diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index c5f7cee4..0c0cc536 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -1,11 +1,13 @@ import asyncio import contextlib import hashlib +import json import logging import os import socket import subprocess import sys +import traceback from asyncio.subprocess import Process from dataclasses import dataclass from typing import Any, AsyncGenerator, Optional @@ -46,16 +48,19 @@ async def get_collections( async def try_server(base_url: str): - for ver in ("v1", "v2"): # v1 for legacy, v2 for latest chromadb. - heartbeat_url = f"{base_url}/api/{ver}/heartbeat" - try: - async with httpx.AsyncClient() as client: - response = await client.get(url=heartbeat_url) - logger.debug(f"Heartbeat {heartbeat_url} returned {response=}") - if response.status_code == 200: - return True - except (httpx.ConnectError, httpx.ConnectTimeout): - pass + openapi_url = f"{base_url}/openapi.json" + try: + async with httpx.AsyncClient() as client: + response = await client.get(url=openapi_url) + logger.debug(f"Fetching openapi.json from {openapi_url}: {response=}") + if response.status_code != 200: + return False + openapi_json = json.loads(response.content.decode()) + if openapi_json: + return openapi_json.get("info", {}).get("title", "").lower() == "chroma" + except Exception as e: + logger.info(f"Failed to connect to chromadb at {base_url}") + logger.debug(traceback.format_exception(e)) return False diff --git a/tests/test_common.py b/tests/test_common.py index c0dbdc5f..c58b581d 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -97,59 +97,6 @@ def test_get_embedding_function_init_exception(): ) -@pytest.mark.asyncio -async def test_try_server_versions(): - # Test successful v1 response - with patch("httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_client.return_value.__aenter__.return_value.get.return_value = ( - mock_response - ) - assert await try_server("http://localhost:8300") is True - mock_client.return_value.__aenter__.return_value.get.assert_called_once_with( - url="http://localhost:8300/api/v1/heartbeat" - ) - - # Test fallback to v2 when v1 fails - with patch("httpx.AsyncClient") as mock_client: - mock_response_v1 = MagicMock() - mock_response_v1.status_code = 404 - mock_response_v2 = MagicMock() - mock_response_v2.status_code = 200 - mock_client.return_value.__aenter__.return_value.get.side_effect = [ - mock_response_v1, - mock_response_v2, - ] - assert await try_server("http://localhost:8300") is True - assert mock_client.return_value.__aenter__.return_value.get.call_count == 2 - - # Test both versions fail - with patch("httpx.AsyncClient") as mock_client: - mock_response_v1 = MagicMock() - mock_response_v1.status_code = 404 - mock_response_v2 = MagicMock() - mock_response_v2.status_code = 500 - mock_client.return_value.__aenter__.return_value.get.side_effect = [ - mock_response_v1, - mock_response_v2, - ] - assert await try_server("http://localhost:8300") is False - - # Test connection error cases - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.side_effect = ( - httpx.ConnectError("Cannot connect") - ) - assert await try_server("http://localhost:8300") is False - - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.side_effect = ( - httpx.ConnectTimeout("Connection timeout") - ) - assert await try_server("http://localhost:8300") is False - - def test_verify_ef(): # Mocking AsyncCollection and Config mock_collection = MagicMock() @@ -190,10 +137,18 @@ async def test_try_server_mocked(mock_socket): with patch("httpx.AsyncClient") as mock_client: mock_response = MagicMock() mock_response.status_code = 200 + mock_response.content = b'{"info":{"title": "Chroma"}}' mock_client.return_value.__aenter__.return_value.get.return_value = ( mock_response ) assert await try_server("http://localhost:8000") is True + with patch("httpx.AsyncClient") as mock_client: + mock_response = MagicMock() + mock_response.status_code = 404 + mock_client.return_value.__aenter__.return_value.get.return_value = ( + mock_response + ) + assert await try_server("http://localhost:8000") is False # Mocking httpx.AsyncClient to raise a ConnectError with patch("httpx.AsyncClient") as mock_client: @@ -202,6 +157,15 @@ async def test_try_server_mocked(mock_socket): ) assert await try_server("http://localhost:8000") is False + with patch("httpx.AsyncClient") as mock_client: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b'{"info":{"title": "Dummy"}}' + mock_client.return_value.__aenter__.return_value.get.return_value = ( + mock_response + ) + assert await try_server("http://localhost:8000") is False + # Mocking httpx.AsyncClient to raise a ConnectTimeout with patch("httpx.AsyncClient") as mock_client: mock_client.return_value.__aenter__.return_value.get.side_effect = (