diff --git a/openviking/async_client.py b/openviking/async_client.py index b4f48ed5..8cffeedb 100644 --- a/openviking/async_client.py +++ b/openviking/async_client.py @@ -7,7 +7,7 @@ """ import threading -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from openviking.client import LocalClient, Session from openviking.service.debug_service import SystemStatus @@ -160,10 +160,23 @@ async def add_resource( timeout=timeout, ) - async def wait_processed(self, timeout: float = None) -> Dict[str, Any]: - """Wait for all queued processing to complete.""" + async def wait_processed( + self, + timeout: float = None, + progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> Dict[str, Any]: + """Wait for all queued processing to complete. + + Args: + timeout: Wait timeout in seconds. + progress_callback: Optional callback invoked each poll iteration with + queue status dicts containing ``pending``, ``in_progress``, + ``processed``, ``error_count``, and ``total`` fields. + """ await self._ensure_initialized() - return await self._client.wait_processed(timeout=timeout) + return await self._client.wait_processed( + timeout=timeout, progress_callback=progress_callback + ) async def add_skill( self, diff --git a/openviking/client/local.py b/openviking/client/local.py index 46acca99..ab274602 100644 --- a/openviking/client/local.py +++ b/openviking/client/local.py @@ -5,7 +5,7 @@ Implements BaseClient interface using direct service calls (embedded mode). """ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from openviking.service import OpenVikingService from openviking_cli.client.base import BaseClient @@ -82,9 +82,15 @@ async def add_skill( timeout=timeout, ) - async def wait_processed(self, timeout: Optional[float] = None) -> Dict[str, Any]: + async def wait_processed( + self, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> Dict[str, Any]: """Wait for all processing to complete.""" - return await self._service.resources.wait_processed(timeout=timeout) + return await self._service.resources.wait_processed( + timeout=timeout, progress_callback=progress_callback + ) # ============= File System ============= diff --git a/openviking/service/resource_service.py b/openviking/service/resource_service.py index dd6e5dcb..6703167f 100644 --- a/openviking/service/resource_service.py +++ b/openviking/service/resource_service.py @@ -6,7 +6,7 @@ Provides resource management operations: add_resource, add_skill, wait_processed. """ -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional from openviking.storage import VikingDBManager from openviking.storage.queuefs import get_queue_manager @@ -164,18 +164,42 @@ async def add_skill( return result - async def wait_processed(self, timeout: Optional[float] = None) -> Dict[str, Any]: + async def wait_processed( + self, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> Dict[str, Any]: """Wait for all queued processing to complete. Args: timeout: Wait timeout in seconds + progress_callback: Optional callback invoked each poll iteration with + queue status dicts. Returns: Queue status """ qm = get_queue_manager() + + if progress_callback is not None: + def adapted(statuses): + progress_callback({ + name: { + "pending": s.pending, + "in_progress": s.in_progress, + "processed": s.processed, + "error_count": s.error_count, + "total": s.total, + } + for name, s in statuses.items() + }) + else: + adapted = None + try: - status = await qm.wait_complete(timeout=timeout) + status = await qm.wait_complete( + timeout=timeout, progress_callback=adapted + ) except TimeoutError as exc: raise DeadlineExceededError("queue processing", timeout) from exc return { diff --git a/openviking/storage/queuefs/named_queue.py b/openviking/storage/queuefs/named_queue.py index 6baeb398..a7578de0 100644 --- a/openviking/storage/queuefs/named_queue.py +++ b/openviking/storage/queuefs/named_queue.py @@ -42,6 +42,11 @@ def has_errors(self) -> bool: def is_complete(self) -> bool: return self.pending == 0 and self.in_progress == 0 + @property + def total(self) -> int: + """Total number of items across all states.""" + return self.processed + self.error_count + self.pending + self.in_progress + class EnqueueHookBase(abc.ABC): """Enqueue hook base class. diff --git a/openviking/storage/queuefs/queue_manager.py b/openviking/storage/queuefs/queue_manager.py index 8832f8b0..bc40489c 100644 --- a/openviking/storage/queuefs/queue_manager.py +++ b/openviking/storage/queuefs/queue_manager.py @@ -294,12 +294,16 @@ async def wait_complete( queue_name: Optional[str] = None, timeout: Optional[float] = None, poll_interval: float = 0.5, + progress_callback=None, ) -> Dict[str, QueueStatus]: """Wait for completion and return final status.""" start = time.time() while True: - if await self.is_all_complete(queue_name): - return await self.check_status(queue_name) + statuses = await self.check_status(queue_name) + if all(s.is_complete for s in statuses.values()): + return statuses + if progress_callback is not None: + progress_callback(statuses) if timeout and (time.time() - start) > timeout: raise TimeoutError(f"Queue processing not complete after {timeout}s") await asyncio.sleep(poll_interval) diff --git a/openviking/sync_client.py b/openviking/sync_client.py index 632c9d7a..348d999f 100644 --- a/openviking/sync_client.py +++ b/openviking/sync_client.py @@ -4,7 +4,7 @@ Synchronous OpenViking client implementation. """ -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional if TYPE_CHECKING: from openviking.session import Session @@ -159,9 +159,23 @@ def rm(self, uri: str, recursive: bool = False) -> None: """Delete resource""" return run_async(self._async_client.rm(uri, recursive)) - def wait_processed(self, timeout: float = None) -> Dict[str, Any]: - """Wait for all async operations to complete""" - return run_async(self._async_client.wait_processed(timeout)) + def wait_processed( + self, + timeout: float = None, + progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> None: + """Wait for all async operations to complete. + + Args: + timeout: Wait timeout in seconds. + progress_callback: Optional callback invoked each poll iteration with + queue status dicts. + """ + return run_async( + self._async_client.wait_processed( + timeout, progress_callback=progress_callback + ) + ) def grep(self, uri: str, pattern: str, case_insensitive: bool = False) -> Dict: """Content search""" diff --git a/openviking_cli/client/base.py b/openviking_cli/client/base.py index b72d4f3a..da87a72d 100644 --- a/openviking_cli/client/base.py +++ b/openviking_cli/client/base.py @@ -6,7 +6,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union class BaseClient(ABC): @@ -53,8 +53,18 @@ async def add_skill( ... @abstractmethod - async def wait_processed(self, timeout: Optional[float] = None) -> Dict[str, Any]: - """Wait for all processing to complete.""" + async def wait_processed( + self, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> Dict[str, Any]: + """Wait for all processing to complete. + + Args: + timeout: Wait timeout in seconds. + progress_callback: Optional callback invoked each poll iteration with + queue status dicts. + """ ... # ============= File System ============= diff --git a/openviking_cli/client/http.py b/openviking_cli/client/http.py index a8ae3726..3e597a67 100644 --- a/openviking_cli/client/http.py +++ b/openviking_cli/client/http.py @@ -5,7 +5,7 @@ Implements BaseClient interface using HTTP calls to OpenViking Server. """ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import httpx @@ -260,8 +260,15 @@ async def add_skill( ) return self._handle_response(response) - async def wait_processed(self, timeout: Optional[float] = None) -> Dict[str, Any]: - """Wait for all processing to complete.""" + async def wait_processed( + self, + timeout: Optional[float] = None, + progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> Dict[str, Any]: + """Wait for all processing to complete. + + Note: progress_callback is not supported in HTTP mode and will be ignored. + """ response = await self._http.post( "/api/v1/system/wait", json={"timeout": timeout}, diff --git a/pyproject.toml b/pyproject.toml index dd777a68..0ce58c0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ python_classes = ["Test*"] python_functions = ["test_*"] asyncio_mode = "auto" addopts = "-v --cov=openviking --cov-report=term-missing" +cache_dir = ".pytest_cache_local" [tool.ruff] line-length = 100 diff --git a/tests/client/test_wait_progress.py b/tests/client/test_wait_progress.py new file mode 100644 index 00000000..d32b3b77 --- /dev/null +++ b/tests/client/test_wait_progress.py @@ -0,0 +1,211 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for progress_callback support in wait_processed() chain.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.storage.queuefs.named_queue import QueueStatus + + +# ============= TestQueueStatusTotal ============= + + +class TestQueueStatusTotal: + """Tests for QueueStatus.total property.""" + + def test_total_all_zeros(self): + status = QueueStatus() + assert status.total == 0 + + def test_total_only_pending(self): + status = QueueStatus(pending=5) + assert status.total == 5 + + def test_total_only_processed(self): + status = QueueStatus(processed=10) + assert status.total == 10 + + def test_total_mixed(self): + status = QueueStatus(pending=2, in_progress=3, processed=4, error_count=1) + assert status.total == 10 + + def test_total_with_errors(self): + status = QueueStatus(processed=7, error_count=3) + assert status.total == 10 + + +# ============= TestQueueManagerProgressCallback ============= + + +class TestQueueManagerProgressCallback: + """Tests for QueueManager.wait_complete() progress_callback.""" + + @pytest.mark.asyncio + async def test_callback_called_with_statuses(self): + """progress_callback should be called with statuses dict each poll iteration.""" + from openviking.storage.queuefs.queue_manager import QueueManager + + qm = QueueManager.__new__(QueueManager) + qm._queues = {} + qm._started = True + qm._agfs = None + + call_count = 0 + statuses_in_progress = QueueStatus(pending=3, in_progress=1, processed=1) + statuses_complete = QueueStatus(pending=0, in_progress=0, processed=5) + + async def mock_check_status(queue_name=None): + nonlocal call_count + call_count += 1 + if call_count < 3: + return {"Embedding": statuses_in_progress} + return {"Embedding": statuses_complete} + + qm.check_status = mock_check_status + + callback = MagicMock() + result = await qm.wait_complete(poll_interval=0.01, progress_callback=callback) + + # callback should be called for non-complete iterations (2 times) + assert callback.call_count == 2 + # Each call should receive the statuses dict + for call_args in callback.call_args_list: + assert "Embedding" in call_args[0][0] + + # Final result should be the complete statuses + assert result["Embedding"].is_complete + + @pytest.mark.asyncio + async def test_none_callback_no_error(self): + """None progress_callback should not cause errors.""" + from openviking.storage.queuefs.queue_manager import QueueManager + + qm = QueueManager.__new__(QueueManager) + qm._queues = {} + qm._started = True + qm._agfs = None + + statuses_complete = QueueStatus(pending=0, in_progress=0, processed=5) + + async def mock_check_status(queue_name=None): + return {"Embedding": statuses_complete} + + qm.check_status = mock_check_status + + result = await qm.wait_complete( + poll_interval=0.01, progress_callback=None + ) + assert result["Embedding"].is_complete + + @pytest.mark.asyncio + async def test_timeout_with_callback(self): + """TimeoutError should still be raised even with progress_callback.""" + from openviking.storage.queuefs.queue_manager import QueueManager + + qm = QueueManager.__new__(QueueManager) + qm._queues = {} + qm._started = True + qm._agfs = None + + statuses_in_progress = QueueStatus(pending=3, in_progress=1) + + async def mock_check_status(queue_name=None): + return {"Embedding": statuses_in_progress} + + qm.check_status = mock_check_status + callback = MagicMock() + + with pytest.raises(TimeoutError): + await qm.wait_complete( + timeout=0.05, + poll_interval=0.01, + progress_callback=callback, + ) + + # callback should have been called at least once + assert callback.call_count >= 1 + + +# ============= TestResourceServiceProgressCallback ============= + + +class TestResourceServiceProgressCallback: + """Tests for ResourceService.wait_processed() progress_callback adapter.""" + + @pytest.mark.asyncio + async def test_adapter_converts_to_dict(self): + """progress_callback adapter should convert QueueStatus to dict with total field.""" + from openviking.service.resource_service import ResourceService + + service = ResourceService() + + captured = [] + + def user_callback(statuses_dict): + captured.append(statuses_dict) + + statuses_in_progress = QueueStatus( + pending=2, in_progress=1, processed=3, error_count=1 + ) + statuses_complete = QueueStatus( + pending=0, in_progress=0, processed=6, error_count=1 + ) + + call_count = 0 + + async def mock_wait_complete(timeout=None, progress_callback=None): + nonlocal call_count + # Simulate one in-progress poll before completion + if progress_callback is not None: + progress_callback({"Embedding": statuses_in_progress}) + return {"Embedding": statuses_complete} + + with patch( + "openviking.service.resource_service.get_queue_manager" + ) as mock_get_qm: + mock_qm = MagicMock() + mock_qm.wait_complete = mock_wait_complete + mock_get_qm.return_value = mock_qm + + result = await service.wait_processed( + progress_callback=user_callback + ) + + # Verify the adapter converted QueueStatus to dict + assert len(captured) == 1 + emb = captured[0]["Embedding"] + assert emb["pending"] == 2 + assert emb["in_progress"] == 1 + assert emb["processed"] == 3 + assert emb["error_count"] == 1 + assert emb["total"] == 7 + + @pytest.mark.asyncio + async def test_backward_compat_no_callback(self): + """wait_processed() without progress_callback should work as before.""" + from openviking.service.resource_service import ResourceService + + service = ResourceService() + + statuses_complete = QueueStatus( + pending=0, in_progress=0, processed=5, error_count=0, errors=[] + ) + + async def mock_wait_complete(timeout=None, progress_callback=None): + assert progress_callback is None + return {"Embedding": statuses_complete} + + with patch( + "openviking.service.resource_service.get_queue_manager" + ) as mock_get_qm: + mock_qm = MagicMock() + mock_qm.wait_complete = mock_wait_complete + mock_get_qm.return_value = mock_qm + + result = await service.wait_processed() + + assert result["Embedding"]["processed"] == 5 + assert result["Embedding"]["error_count"] == 0