diff --git a/docker-compose.yml b/docker-compose.yml
index bbb3718fecb..48de8683c0b 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -135,6 +135,11 @@ services:
       FIDES__CONFIG_PATH: ${FIDES__CONFIG_PATH:-/fides/.fides/fides.toml}
       FIDES__LOGGING__COLORIZE: "True"
       FIDES__USER__ANALYTICS_OPT_OUT: "True"
+      # The default HTTP health check port is 9000, override it here to ensure that
+      # the override works as expected.
+      FIDES__CELERY__HEALTHCHECK_PORT: "9001"
+    expose:
+      - 9001
     volumes:
       - type: bind
         source: ./
@@ -143,52 +148,16 @@ services:
       - /fides/src/fides.egg-info
 
   worker-privacy-preferences:
-    image: ethyca/fides:local
+    extends:
+      service: worker-other
     command: fides worker --queues=fides.privacy_preferences,fides.privacy_request_exports,fides.privacy_request_ingestion
-    depends_on:
-      redis:
-        condition: service_started
-    restart: always
-    healthcheck:
-      test: ["CMD", "celery", "-A", "fides.api.tasks", "inspect", "ping"]
-      start_period: 60s
-      interval: 20s
-      timeout: 5s
-      retries: 10
-    environment:
-      FIDES__CONFIG_PATH: ${FIDES__CONFIG_PATH:-/fides/.fides/fides.toml}
-      FIDES__LOGGING__COLORIZE: "True"
-      FIDES__USER__ANALYTICS_OPT_OUT: "True"
-    volumes:
-      - type: bind
-        source: ./
-        target: /fides
-        read_only: False
-      - /fides/src/fides.egg-info
 
   worker-dsr:
-    image: ethyca/fides:local
-    command: fides worker --queues=fides.dsr
-    depends_on:
-      redis:
-        condition: service_started
-    restart: always
+    extends:
+      service: worker-other
     healthcheck:
-      test: ["CMD", "celery", "-A", "fides.api.tasks", "inspect", "ping"]
-      start_period: 60s
-      interval: 20s
-      timeout: 5s
-      retries: 10
-    environment:
-      FIDES__CONFIG_PATH: ${FIDES__CONFIG_PATH:-/fides/.fides/fides.toml}
-      FIDES__LOGGING__COLORIZE: "True"
-      FIDES__USER__ANALYTICS_OPT_OUT: "True"
-    volumes:
-      - type: bind
-        source: ./
-        target: /fides
-        read_only: False
-      - /fides/src/fides.egg-info
+      test: ["CMD", "curl", "-f", "http://localhost:9001/"]
+    command: fides worker --queues=fides.dsr
 
   redis:
     image: "redis:8.0-alpine"
diff --git a/src/fides/api/tasks/__init__.py b/src/fides/api/tasks/__init__.py
index 63b992673da..08b0fcf4cd9 100644
--- a/src/fides/api/tasks/__init__.py
+++ b/src/fides/api/tasks/__init__.py
@@ -14,6 +14,7 @@
 )
 
 from fides.api.db.session import get_db_engine, get_db_session
+from fides.api.tasks import celery_healthcheck
 from fides.api.util.logger import setup as setup_logging
 from fides.config import CONFIG, FidesConfig
 
@@ -102,6 +103,7 @@ def _create_celery(config: FidesConfig = CONFIG) -> Celery:
     )
 
     app = Celery(__name__)
+    celery_healthcheck.register(app)  # type: ignore
 
     celery_config: Dict[str, Any] = {
         # Defaults for the celery config
@@ -112,6 +114,8 @@ def _create_celery(config: FidesConfig = CONFIG) -> Celery:
         # Ops requires this to route emails to separate queues
         "task_create_missing_queues": True,
         "task_default_queue": "fides",
+        "healthcheck_port": config.celery.healthcheck_port,
+        "healthcheck_ping_timeout": config.celery.healthcheck_ping_timeout,
     }
 
     celery_config.update(config.celery)
diff --git a/src/fides/api/tasks/celery_healthcheck/__init__.py b/src/fides/api/tasks/celery_healthcheck/__init__.py
new file mode 100644
index 00000000000..b2c85f80985
--- /dev/null
+++ b/src/fides/api/tasks/celery_healthcheck/__init__.py
@@ -0,0 +1,12 @@
+# fmt: off
+# type: ignore
+# pylint: skip-file
+# isort:off
+
+
+from .server import HealthCheckServer
+
+
+def register(celery_app):
+    """Add the HTTP health check bootstep to the worker steps of ``celery_app``."""
+    celery_app.steps["worker"].add(HealthCheckServer)
diff --git a/src/fides/api/tasks/celery_healthcheck/server.py b/src/fides/api/tasks/celery_healthcheck/server.py
new file mode 100644
index 00000000000..c21241201f8
--- /dev/null
+++ b/src/fides/api/tasks/celery_healthcheck/server.py
@@ -0,0 +1,138 @@
+import json
+import threading
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import Any, Optional
+
+from celery import bootsteps
+from celery.worker import WorkController
+from loguru import logger
+
+HEALTHCHECK_DEFAULT_PORT = 9000
+HEALTHCHECK_DEFAULT_PING_TIMEOUT = 2.0
+HEALTHCHECK_DEFAULT_HTTP_SERVER_SHUTDOWN_TIMEOUT = 2.0
+
+
+class HealthcheckHandler(BaseHTTPRequestHandler):
+    """Answer GET requests with the result of pinging the owning Celery worker.
+
+    Built on ``BaseHTTPRequestHandler`` rather than ``SimpleHTTPRequestHandler``
+    so that no static-file serving (e.g. via HEAD) is inherited.
+    """
+
+    def __init__(
+        self, parent: WorkController, healthcheck_ping_timeout: float, *args: Any
+    ):
+        # Set these first: the base constructor handles the request immediately.
+        self.parent = parent
+        self.healthcheck_ping_timeout = healthcheck_ping_timeout
+        super().__init__(*args)
+
+    def _send_json(self, status_code: int, payload: Any) -> None:
+        """Send ``payload`` as a complete JSON response with the given status."""
+        # Serialize before emitting anything so a failure leaves no partial response
+        body = json.dumps(payload).encode("utf-8")
+        self.send_response(status_code)
+        self.send_header("Content-type", "application/json")
+        self.send_header("Content-Length", str(len(body)))
+        self.end_headers()
+        self.wfile.write(body)
+
+    def do_GET(self) -> None:
+        """Ping this worker; reply 200 if it answered and 503 if it did not."""
+        try:
+            try:
+                parent = self.parent
+                insp = parent.app.control.inspect(
+                    destination=[parent.hostname], timeout=self.healthcheck_ping_timeout
+                )
+                result = insp.ping()
+                logger.debug(f"Healthcheck ping result: {result}")
+
+                if result:
+                    self._send_json(200, {"status": "ok", "data": result})
+                else:
+                    # An empty ping result means the worker did not reply
+                    # within the timeout, so it must not be reported healthy.
+                    self._send_json(
+                        503, {"status": "error", "data": "No ping reply from worker"}
+                    )
+            except Exception as e:
+                logger.warning(f"Healthcheck ping exception: {e}")
+                self._send_json(503, {"status": "error", "data": str(e)})
+        except Exception as ex:
+            logger.exception("HealthcheckHandler exception", exc_info=ex)
+            try:
+                self.send_response(500)
+                # Without end_headers() the buffered status line is never sent
+                self.end_headers()
+            except OSError:
+                logger.debug("Unable to send 500 response to healthcheck client")
+
+
+class HealthCheckServer(bootsteps.StartStopStep):
+    """Worker bootstep that serves the HTTP health check from a daemon thread."""
+
+    # ignore kwargs type
+    def __init__(self, parent: WorkController, **kwargs):  # type: ignore [arg-type, no-untyped-def]
+        self.thread: Optional[threading.Thread] = None
+        self.http_server: Optional[HTTPServer] = None
+
+        self.parent = parent
+
+        # config
+        self.healthcheck_port = int(
+            getattr(parent.app.conf, "healthcheck_port", HEALTHCHECK_DEFAULT_PORT)
+        )
+        self.healthcheck_ping_timeout = float(
+            getattr(
+                parent.app.conf,
+                "healthcheck_ping_timeout",
+                HEALTHCHECK_DEFAULT_PING_TIMEOUT,
+            )
+        )
+        self.shutdown_timeout = float(
+            getattr(
+                parent.app.conf,
+                "shutdown_timeout",
+                HEALTHCHECK_DEFAULT_HTTP_SERVER_SHUTDOWN_TIMEOUT,
+            )
+        )
+
+        super().__init__(**kwargs)
+
+    # The mypy hints for an HTTP handler are strange, so ignoring them here
+    def http_handler(self, *args) -> None:  # type: ignore [arg-type, no-untyped-def]
+        HealthcheckHandler(self.parent, self.healthcheck_ping_timeout, *args)
+
+    def start(self, parent: WorkController) -> None:
+        # Ignore mypy hints here as the constructed object immediately handles the request
+        # (the handler's base constructor runs setup/handle/finish itself)
+        self.http_server = HTTPServer(
+            ("0.0.0.0", self.healthcheck_port), self.http_handler  # type: ignore [arg-type]
+        )
+
+        self.thread = threading.Thread(
+            target=self.http_server.serve_forever, daemon=True
+        )
+        self.thread.start()
+
+    def stop(self, parent: WorkController) -> None:
+        if self.http_server is None:
+            logger.warning(
+                "Requested stop of HTTP healthcheck server, but no server was started"
+            )
+        else:
+            logger.info(
+                f"Stopping health check server with a timeout of {self.shutdown_timeout} seconds"
+            )
+            self.http_server.shutdown()
+            # shutdown() only ends serve_forever(); also release the listening socket
+            self.http_server.server_close()
+
+        # Really this should not happen if the HTTP server is None, but just in case, we should check.
+        if self.thread is None:
+            logger.warning("No thread in HTTP healthcheck server to shutdown...")
+        else:
+            self.thread.join(self.shutdown_timeout)
+
+        logger.info(f"Health check server stopped on port {self.healthcheck_port}")
diff --git a/src/fides/config/celery_settings.py b/src/fides/config/celery_settings.py
index ea185e862a4..841ee60dbf2 100644
--- a/src/fides/config/celery_settings.py
+++ b/src/fides/config/celery_settings.py
@@ -27,6 +27,12 @@ class CelerySettings(FidesSettings):
         description="If true, tasks are executed locally instead of being sent to the queue. "
         "If False, tasks are sent to the queue.",
     )
+    healthcheck_port: int = Field(
+        default=9000, description="The port to use for the health check endpoint"
+    )
+    healthcheck_ping_timeout: float = Field(
+        default=2.0, description="The timeout in seconds for the health check ping"
+    )
 
     model_config = SettingsConfigDict(env_prefix=ENV_PREFIX)
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 3ca47f6dd55..b559fe4c9c5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -59,7 +59,7 @@
 from fides.api.schemas.messaging.messaging import MessagingServiceType
 from fides.api.schemas.privacy_request import PrivacyRequestStatus
 from fides.api.task.graph_runners import access_runner, consent_runner, erasure_runner
-from fides.api.tasks import celery_app
+from fides.api.tasks import celery_app, celery_healthcheck
 from fides.api.tasks.scheduled.scheduler import async_scheduler, scheduler
 from fides.api.util.cache import get_cache
 from fides.api.util.collection_util import Row
@@ -765,15 +765,44 @@ def celery_enable_logging():
     return True
 
 
+# Register health check for workers
 @pytest.fixture(scope="session")
-def celery_worker_parameters():
-    """Configure celery worker parameters for testing.
-
-    Increase shutdown_timeout to avoid flaky test failures when the worker
-    takes longer to shut down, especially during parallel test runs with pytest-xdist.
-    The CI environment can be slow, so we use a generous timeout.
-    """
-    return {"shutdown_timeout": 180.0}
+def celery_session_app(celery_session_app):
+    celery_healthcheck.register(celery_session_app)
+    return celery_session_app
+
+
+# This is here because the test suite occasionally fails to teardown the
+# Celery worker if it takes too long to terminate the worker thread. This
+# will prevent that and, instead, log a warning
+@pytest.fixture(scope="session")
+def celery_session_worker(
+    request,
+    celery_session_app,
+    celery_includes,
+    celery_class_tasks,
+    celery_worker_pool,
+    celery_worker_parameters,
+):
+    from celery.contrib.testing import worker
+
+    for module in celery_includes:
+        celery_session_app.loader.import_task_module(module)
+    for class_task in celery_class_tasks:
+        celery_session_app.register_task(class_task)
+
+    # Default first so an overridden celery_worker_parameters cannot collide with it
+    worker_kwargs = {"shutdown_timeout": 2.0, **celery_worker_parameters}
+
+    try:
+        logger.info("Starting safe celery session worker...")
+        with worker.start_worker(
+            celery_session_app, pool=celery_worker_pool, **worker_kwargs
+        ) as w:
+            yield w
+            logger.info("Done with celery worker, trying to dispose of it..")
+    except RuntimeError as exc:
+        logger.warning(f"Failed to stop the celery worker: {exc}")
 
 
 @pytest.fixture(autouse=True, scope="session")
@@ -880,7 +909,8 @@ def access_runner_tester(
             connection_configs,
             identity,
             session,
-            privacy_request_proceed=False,  # This allows the DSR 3.0 Access Runner to be tested in isolation, to just test running the access graph without queuing the privacy request
+            # This allows the DSR 3.0 Access Runner to be tested in isolation, to just test running the access graph without queuing the privacy request
+            privacy_request_proceed=False,
         )
     except PrivacyRequestExit:
         # DSR 3.0 intentionally raises a PrivacyRequestExit status while it waits for
@@ -938,7 +968,8 @@ def consent_runner_tester(
             connection_configs,
             identity,
             session,
-            privacy_request_proceed=False,  # This allows the DSR 3.0 Consent Runner to be tested in isolation, to just test running the consent graph without queuing the privacy request
+            # This allows the DSR 3.0 Consent Runner to be tested in isolation, to just test running the consent graph without queuing the privacy request
+            privacy_request_proceed=False,
         )
     except PrivacyRequestExit:
         # DSR 3.0 intentionally raises a PrivacyRequestExit status while it waits for
diff --git a/tests/task/test_healthcheck_server.py b/tests/task/test_healthcheck_server.py
new file mode 100644
index 00000000000..171ffba9680
--- /dev/null
+++ b/tests/task/test_healthcheck_server.py
@@ -0,0 +1,33 @@
+import pytest
+import requests
+from loguru import logger
+
+HEALTHCHECK_URL = "http://127.0.0.1:9000/"
+
+
+class TestCeleryHealthCheckServer:
+    def test_responds_to_ping_properly(self, celery_session_app, celery_session_worker):
+        try:
+            response = requests.get(HEALTHCHECK_URL, timeout=5)
+        except requests.exceptions.ConnectionError:
+            pytest.fail("Connection error")
+
+        assert response.status_code == 200
+        assert response.json()["status"] == "ok"
+
+
+class TestCeleryHealthCheckWorker:
+    @pytest.fixture(autouse=True)
+    def setup_teardown(self):
+        yield
+        # Once the worker is stopped the endpoint must no longer answer
+        with pytest.raises(requests.exceptions.RequestException):
+            requests.get(HEALTHCHECK_URL, timeout=1)
+
+    def test_shutdown_gracefully(self, celery_session_app, celery_session_worker):
+        try:
+            logger.info("Shutdown gracefully")
+            celery_session_worker.stop()
+            logger.info("Shutdown gracefully finished")
+        except Exception:
+            pytest.fail("Failed to stop health check server")