From a5f74c30628acb30f2925847862553169ccb7650 Mon Sep 17 00:00:00 2001 From: extreme4all <40169115+extreme4all@users.noreply.github.com> Date: Thu, 5 Feb 2026 18:29:27 +0100 Subject: [PATCH] Simplify ConsumerWorker queue usage --- .../event_queue/core/event_queue.py | 7 +- components/bot_detector/worker/__init__.py | 17 +- components/bot_detector/worker/core.py | 240 ++++++++++-------- components/bot_detector/worker/interface.py | 12 +- test/components/bot_detector/worker/kafka.py | 30 ++- .../bot_detector/worker/test_worker.py | 114 +++++---- 6 files changed, 248 insertions(+), 172 deletions(-) diff --git a/components/bot_detector/event_queue/core/event_queue.py b/components/bot_detector/event_queue/core/event_queue.py index d43ed42..dbfdb51 100644 --- a/components/bot_detector/event_queue/core/event_queue.py +++ b/components/bot_detector/event_queue/core/event_queue.py @@ -32,8 +32,8 @@ async def start(self): async def stop(self): await self._backend.stop() - async def put(self, message: list[T]): - await self._backend.put(message) + async def put(self, message: list[T]) -> Optional[Exception]: + return await self._backend.put(message) class QueueConsumer(Generic[T]): @@ -64,6 +64,9 @@ async def get_one(self) -> Optional[T] | Exception: async def get_many(self, count: int) -> list[T] | Exception: return await self._backend.get_many(count) + async def commit(self) -> Optional[Exception]: + return await self._backend.commit() + class Queue(QueueConsumer[T], QueueProducer[T]): """ diff --git a/components/bot_detector/worker/__init__.py b/components/bot_detector/worker/__init__.py index 0cd4f50..51c595d 100644 --- a/components/bot_detector/worker/__init__.py +++ b/components/bot_detector/worker/__init__.py @@ -1,4 +1,15 @@ -from .core import BaseWorker -from .interface import WorkerInterface +from .core import BaseWorker, ConsumerWorker, ProducerWorker +from .interface import ( + ConsumerWorkerInterface, + ProducerWorkerInterface, + WorkerInterface, +) -__all__ = ["BaseWorker", "WorkerInterface"] +__all__ = [ + "BaseWorker", + "ConsumerWorker", + "ConsumerWorkerInterface", + "ProducerWorker", + "ProducerWorkerInterface", + "WorkerInterface", +] diff --git a/components/bot_detector/worker/core.py b/components/bot_detector/worker/core.py index ef9ef0b..4f0087b 100644 --- a/components/bot_detector/worker/core.py +++ b/components/bot_detector/worker/core.py @@ -2,138 +2,147 @@ import logging from typing import Any, Generic, TypeVar -from bot_detector.kafka import ConsumerInterface, ProducerInterface +from bot_detector.event_queue.core import Queue, QueueProducer from bot_detector.wide_event import EventLoggerInterface, WideEventLogger from pydantic import BaseModel -from .interface import WorkerInterface +from .interface import ConsumerWorkerInterface, ProducerWorkerInterface T = TypeVar("T", bound=BaseModel) -class BaseWorker(Generic[T], WorkerInterface[T]): +class BaseWorker(Generic[T]): """Generic worker with minimal boilerplate and integrated logging.""" EMPTY_MESSAGE_SLEEP = 10 - PRODUCE_RETRY_DELAY = 5 - PRODUCE_MAX_RETRY = 3 def __init__( self, - consumer: ConsumerInterface[T], - producer: ProducerInterface[T], - max_messages: int = 10_000, - max_interval_ms: int = 5_000, - batch_processing: bool = False, wide_event: EventLoggerInterface = WideEventLogger(sample_ratio=0.1), logger_name: str | None = None, ) -> None: - self._consumer = consumer - self._producer = producer - self._max_messages = max_messages - self._max_interval_ms = max_interval_ms - self._batch_processing = batch_processing self._wide_event = wide_event self._logger = logging.getLogger(logger_name or self.__class__.__name__) self._stop_event = asyncio.Event() - async def on_message(self, message: T) -> bool: - """Override with single message logic. Return True for success.""" - return True - - async def on_message_batch(self, messages: list[T]) -> bool: - """Override with batch message logic. Return True if all succeed.""" - for msg in messages: - if not await self.on_message(msg): - return False - return True - - async def start(self) -> None: - await self._consumer.start() - await self._producer.start() - await self._run() - - async def stop(self) -> None: - self._stop_event.set() - await self._consumer.stop() - await self._producer.stop() - # ------------------------ # Hooks # ------------------------ - async def _consume_error_hook(self, errors: list[str]): - self._add_context({"error": {"consumer_errors": errors[:5]}}) + async def _consume_error_hook(self, errors: list[Exception]): + error_messages = [str(error) for error in errors[:5]] + self._add_context({"error": {"consumer_errors": error_messages}}) async def _empty_message_hook(self): self._add_context({"_run": {"status": "empty"}}) await asyncio.sleep(self.EMPTY_MESSAGE_SLEEP) - async def _failed_on_message_hook(self): - self._add_context({"_run": {"status": "failed"}}) + async def _failed_on_message_hook(self, error: Exception): + self._add_context({"_run": {"status": "failed", "error": str(error)}}) async def _success_on_message_hook(self): self._add_context({"_run": {"status": "success"}}) # ------------------------ - # Retry logic + # WideEvent logging # ------------------------ + def _set_context(self, data: dict) -> Any: + return self._wide_event.set(data) + + def _add_context(self, data: dict) -> None: + self._wide_event.add(data) + + def _get_context(self) -> dict: + return self._wide_event.get() + + def _reset_context(self, token: Any) -> None: + self._wide_event.reset(token) + + def _log(self) -> None: + context = self._get_context() + if "error" in context: + self._logger.error(context) + elif self._wide_event.sample(): + self._logger.info(context) + + +class ConsumerWorker(BaseWorker[T], ConsumerWorkerInterface[T]): + """Worker that consumes messages from a queue and processes them.""" + + MAX_MESSAGES = 10_000 + + def __init__( + self, + queue: Queue[T], + batch_processing: bool = False, + wide_event: EventLoggerInterface = WideEventLogger(sample_ratio=0.1), + logger_name: str | None = None, + ) -> None: + super().__init__(wide_event=wide_event, logger_name=logger_name) + self._queue = queue + self._batch_processing = batch_processing + + async def on_message(self, message: T) -> Exception | None: + """Override with single message logic. Return Exception on failure.""" + return None + + async def on_message_batch(self, messages: list[T]) -> Exception | None: + """Override with batch message logic. Return Exception on failure.""" + for msg in messages: + error = await self.on_message(msg) + if error: + return error + return None + + async def start(self) -> None: + await self._queue.start() + await self._run() + + async def stop(self) -> None: + self._stop_event.set() + await self._queue.stop() + async def _produce_failed_messages(self, batch: list[T]) -> None: - errors: list[str] = [] - - for message in batch: - retry_count = 0 - while retry_count < self.PRODUCE_MAX_RETRY: - try: - await self._producer.produce_one(message=message) - break - except Exception as e: - retry_count += 1 - if retry_count >= self.PRODUCE_MAX_RETRY: - errors.append(str(e)) - break - await asyncio.sleep(self.PRODUCE_RETRY_DELAY) - - if errors: - self._add_context({"error": {"produce_failed_messages": errors[:5]}}) + error = await self._queue.put(batch) + if error: + self._add_context({"error": {"produce_failed_messages": [str(error)]}}) - # ------------------------ - # Core processing loops - # ------------------------ async def _run_one(self) -> list[T]: failed: list[T] = [] - message, consume_error = await self._consumer.consume_one() - if consume_error: - await self._consume_error_hook([consume_error]) - if message is None: + result = await self._queue.get_one() + if isinstance(result, Exception): + await self._consume_error_hook([result]) await self._empty_message_hook() + return failed + if result is None: + await self._empty_message_hook() + return failed + message = result + error = await self.on_message(message) + if error is None: + await self._success_on_message_hook() else: - if await self.on_message(message): - await self._success_on_message_hook() - else: - failed.append(message) - await self._failed_on_message_hook() + failed.append(message) + await self._failed_on_message_hook(error) return failed async def _run_many(self) -> list[T]: failed_messages: list[T] = [] - batch, errors = await self._consumer.consume_many( - max_records=self._max_messages, timeout_ms=self._max_interval_ms - ) - - if errors: - await self._consume_error_hook(errors) - - if not batch: + result = await self._queue.get_many(self.MAX_MESSAGES) + if isinstance(result, Exception): + await self._consume_error_hook([result]) await self._empty_message_hook() - return [] - + return failed_messages + if not result: + await self._empty_message_hook() + return failed_messages + batch = result self._add_context({"_run": {"batch_size": len(batch)}}) - - if await self.on_message_batch(batch): + error = await self.on_message_batch(batch) + if error is None: await self._success_on_message_hook() else: failed_messages.extend(batch) - await self._failed_on_message_hook() + await self._failed_on_message_hook(error) return failed_messages @@ -146,7 +155,9 @@ async def _run(self) -> None: else: failed_messages = await self._run_one() await self._produce_failed_messages(failed_messages) - await self._consumer.commit() + commit_error = await self._queue.commit() + if commit_error: + await self._consume_error_hook([commit_error]) except Exception as e: self._add_context({"error": {"_run_exception": str(e)}}) await asyncio.sleep(15) @@ -154,24 +165,51 @@ async def _run(self) -> None: self._log() self._reset_context(token) - # ------------------------ - # WideEvent logging - # ------------------------ - def _set_context(self, data: dict) -> Any: - return self._wide_event.set(data) - def _add_context(self, data: dict) -> None: - self._wide_event.add(data) +class ProducerWorker(BaseWorker[T], ProducerWorkerInterface[T]): + """Worker that builds and produces messages to a queue.""" - def _get_context(self) -> dict: - return self._wide_event.get() + def __init__( + self, + queue: QueueProducer[T] | Queue[T], + wide_event: EventLoggerInterface = WideEventLogger(sample_ratio=0.1), + logger_name: str | None = None, + ) -> None: + super().__init__(wide_event=wide_event, logger_name=logger_name) + self._queue = queue - def _reset_context(self, token: Any) -> None: - self._wide_event.reset(token) + async def build_messages(self) -> list[T] | Exception | None: + """Override to return a batch of messages to produce.""" + return None - def _log(self) -> None: - context = self._get_context() - if "error" in context: - self._logger.error(context) - elif self._wide_event.sample(): - self._logger.info(context) + async def start(self) -> None: + await self._queue.start() + await self._run() + + async def stop(self) -> None: + self._stop_event.set() + await self._queue.stop() + + async def _run(self) -> None: + while not self._stop_event.is_set(): + token = self._set_context(data={}) + try: + result = await self.build_messages() + if isinstance(result, Exception): + await self._failed_on_message_hook(result) + await self._empty_message_hook() + continue + if not result: + await self._empty_message_hook() + continue + error = await self._queue.put(result) + if error: + self._add_context( + {"error": {"produce_failed_messages": [str(error)]}} + ) + except Exception as e: + self._add_context({"error": {"_run_exception": str(e)}}) + await asyncio.sleep(15) + finally: + self._log() + self._reset_context(token) diff --git a/components/bot_detector/worker/interface.py b/components/bot_detector/worker/interface.py index 21f097d..f998d75 100644 --- a/components/bot_detector/worker/interface.py +++ b/components/bot_detector/worker/interface.py @@ -5,9 +5,15 @@ T = TypeVar("T", bound=BaseModel) -class WorkerInterface(Protocol, Generic[T]): # pragma: no cover +class WorkerInterface(Protocol): # pragma: no cover async def start(self) -> None: ... async def stop(self) -> None: ... - async def on_message(self, message: T) -> bool: ... - async def on_message_batch(self, messages: list[T]) -> bool: ... + +class ConsumerWorkerInterface(WorkerInterface, Generic[T]): # pragma: no cover + async def on_message(self, message: T) -> Exception | None: ... + async def on_message_batch(self, messages: list[T]) -> Exception | None: ... + + +class ProducerWorkerInterface(WorkerInterface, Generic[T]): # pragma: no cover + async def build_messages(self) -> list[T] | Exception | None: ... diff --git a/test/components/bot_detector/worker/kafka.py b/test/components/bot_detector/worker/kafka.py index 1f34e15..9494f78 100644 --- a/test/components/bot_detector/worker/kafka.py +++ b/test/components/bot_detector/worker/kafka.py @@ -15,7 +15,7 @@ class DummyConsumer: Simulates a Kafka consumer. Uses a list of dicts to simulate messages and errors. Each item in `queue` is either: {"message": TestMessage(...)} -> a normal message - {"error": "some error"} -> a consumer error + {"error": Exception(...)} -> a consumer error """ def __init__(self, queue: list[dict] | None = None) -> None: @@ -25,25 +25,21 @@ def __init__(self, queue: list[dict] | None = None) -> None: self.commit = AsyncMock() self.get_consumer = AsyncMock() - async def consume_one(self) -> tuple[TestMessage | None, str | None]: + async def get_one(self) -> TestMessage | None | Exception: if not self.queue: - return None, None + return None item = self.queue.pop(0) - return item.get("message"), item.get("error") + return item.get("message") or item.get("error") - async def consume_many( - self, max_records: int, timeout_ms: int - ) -> tuple[list[TestMessage], list[str]]: + async def get_many(self, count: int) -> list[TestMessage] | Exception: batch: list[TestMessage] = [] - errors: list[str] = [] - - for _ in range(min(max_records, len(self.queue))): + for _ in range(min(count, len(self.queue))): item = self.queue.pop(0) if "message" in item: batch.append(item["message"]) elif "error" in item: - errors.append(item["error"]) - return batch, errors + return item["error"] + return batch async def get_lag(self) -> int: return 0 @@ -56,4 +52,12 @@ def __init__(self) -> None: self.start = AsyncMock() self.stop = AsyncMock() self.get_producer = AsyncMock() - self.produce_one = AsyncMock() + self.put = AsyncMock() + + +class DummyQueue(DummyConsumer, DummyProducer): + """Simulates a Queue with both consumer and producer capabilities.""" + + def __init__(self, queue: list[dict] | None = None) -> None: + DummyConsumer.__init__(self, queue=queue) + DummyProducer.__init__(self) diff --git a/test/components/bot_detector/worker/test_worker.py b/test/components/bot_detector/worker/test_worker.py index 2a88577..cf95ddd 100644 --- a/test/components/bot_detector/worker/test_worker.py +++ b/test/components/bot_detector/worker/test_worker.py @@ -3,40 +3,38 @@ from unittest.mock import AsyncMock, Mock import pytest -from bot_detector.worker import BaseWorker +from bot_detector.worker import ConsumerWorker, ProducerWorker from test.components.bot_detector.worker.kafka import ( - DummyConsumer, DummyProducer, + DummyQueue, TestMessage, ) # --- Test worker --- -class TestWorker(BaseWorker[TestMessage]): +class TestWorker(ConsumerWorker[TestMessage]): # overwrite the sleeps to speed up tests EMPTY_MESSAGE_SLEEP = 0 - PRODUCE_RETRY_DELAY = 0 - PRODUCE_MAX_RETRY = 3 def __init__( self, - *args: Any, - on_message_result: Callable[[TestMessage], bool] | None = None, - on_batch_result: Callable[[list[TestMessage]], bool] | None = None, + queue: DummyQueue, + on_message_result: Callable[[TestMessage], Exception | None] | None = None, + on_batch_result: Callable[[list[TestMessage]], Exception | None] | None = None, **kwargs: Any, ) -> None: - super().__init__(*args, **kwargs) + super().__init__(queue=queue, **kwargs) self.seen_messages: list[TestMessage] = [] self.seen_batches: list[list[TestMessage]] = [] - self._on_message_result = on_message_result or (lambda _: True) - self._on_batch_result = on_batch_result or (lambda _: True) + self._on_message_result = on_message_result or (lambda _: None) + self._on_batch_result = on_batch_result or (lambda _: None) - async def on_message(self, message: TestMessage) -> bool: + async def on_message(self, message: TestMessage) -> Exception | None: self.seen_messages.append(message) return self._on_message_result(message) - async def on_message_batch(self, messages: list[TestMessage]) -> bool: + async def on_message_batch(self, messages: list[TestMessage]) -> Exception | None: self.seen_batches.append(messages) return self._on_batch_result(messages) @@ -50,12 +48,9 @@ async def test_worker_process_message_success(): ] consumer_queue = [{"message": m} for m in messages] - consumer = DummyConsumer(queue=consumer_queue) - producer = DummyProducer() - + queue = DummyQueue(queue=consumer_queue) worker = TestWorker( - consumer=consumer, - producer=producer, + queue=queue, batch_processing=False, ) @@ -74,8 +69,8 @@ async def test_worker_process_message_success(): await task assert worker.seen_messages == messages - producer.produce_one.assert_not_called() - consumer.commit.assert_called() + queue.put.assert_not_called() + queue.commit.assert_called() logger.error.assert_not_called() @@ -83,18 +78,15 @@ async def test_worker_process_message_success(): async def test_worker_retries_failed_message(monkeypatch): messages = [TestMessage(id="1", data="a"), TestMessage(id="2", data="b")] consumer_queue = [{"message": m} for m in messages] - consumer = DummyConsumer(queue=consumer_queue) - producer = DummyProducer() + queue = DummyQueue(queue=consumer_queue) # Fail only message with id="2" - def fail_on_message(msg: TestMessage) -> bool: - return msg.id != "2" + def fail_on_message(msg: TestMessage) -> Exception | None: + if msg.id == "2": + return ValueError("failed") + return None - worker = TestWorker( - consumer=consumer, - producer=producer, - on_message_result=fail_on_message, - ) + worker = TestWorker(queue=queue, on_message_result=fail_on_message) # Run the worker task = asyncio.create_task(worker.start()) @@ -105,22 +97,19 @@ def fail_on_message(msg: TestMessage) -> bool: await worker.stop() await task - producer.produce_one.assert_called_once_with(message=messages[1]) - consumer.commit.assert_called() + queue.put.assert_called_once_with([messages[1]]) + queue.commit.assert_called() @pytest.mark.asyncio async def test_worker_batch_processing_retries_all(monkeypatch): messages = [TestMessage(id="1", data="a"), TestMessage(id="2", data="b")] consumer_queue = [{"message": m} for m in messages] - consumer = DummyConsumer(queue=consumer_queue) - producer = DummyProducer() - + queue = DummyQueue(queue=consumer_queue) worker = TestWorker( - consumer=consumer, - producer=producer, + queue=queue, batch_processing=True, - on_batch_result=lambda _: False, + on_batch_result=lambda _: ValueError("failed"), ) # Run the worker @@ -134,16 +123,15 @@ async def test_worker_batch_processing_retries_all(monkeypatch): assert worker.seen_batches == [messages] assert worker.seen_messages == [] - assert producer.produce_one.call_count == len(messages) - consumer.commit.assert_called() + queue.put.assert_called_once_with(messages) + queue.commit.assert_called() @pytest.mark.asyncio async def test_worker_logs_consumer_errors(monkeypatch): - consumer_queue = [{"error": "oops"}] - consumer = DummyConsumer(queue=consumer_queue) - producer = DummyProducer() - worker = TestWorker(consumer=consumer, producer=producer) + consumer_queue = [{"error": ValueError("oops")}] + queue = DummyQueue(queue=consumer_queue) + worker = TestWorker(queue=queue) logger = Mock() logger.info = Mock() @@ -163,16 +151,42 @@ async def test_worker_logs_consumer_errors(monkeypatch): @pytest.mark.asyncio async def test_worker_start_stop_calls_dependencies(): - consumer = DummyConsumer() - producer = DummyProducer() - worker = TestWorker(consumer=consumer, producer=producer) + queue = DummyQueue() + worker = TestWorker(queue=queue) worker._run = AsyncMock() # prevent actual loop await worker.start() await worker.stop() - consumer.start.assert_called_once() - producer.start.assert_called_once() + queue.start.assert_called_once() worker._run.assert_called_once() - consumer.stop.assert_called_once() - producer.stop.assert_called_once() + queue.stop.assert_called_once() + + +class TestProducerWorker(ProducerWorker[TestMessage]): + def __init__( + self, + queue: DummyProducer, + batch: list[TestMessage] | None = None, + ) -> None: + super().__init__(queue=queue) + self._batch = batch + + async def build_messages(self) -> list[TestMessage] | Exception | None: + return self._batch + + +@pytest.mark.asyncio +async def test_producer_worker_puts_batch(): + producer = DummyProducer() + worker = TestProducerWorker( + queue=producer, + batch=[TestMessage(id="1", data="a")], + ) + worker._empty_message_hook = AsyncMock() + task = asyncio.create_task(worker.start()) + while not producer.put.called: + await asyncio.sleep(0) + await worker.stop() + await task + producer.put.assert_called_once()