From 47fc495c5486a48e2979d2546d94e514e94c2723 Mon Sep 17 00:00:00 2001 From: Luca Candela Date: Sat, 11 Oct 2025 22:13:46 -0700 Subject: [PATCH 1/7] restructure tests into unit integration e2e tiers --- .github/workflows/ci.yml | 4 +- MCP_SMOKE_DEBUG.md | 183 -------- TEST_RESTRUCTURE_PLAN.md | 442 ++++++++++++++++++ conftest.py | 9 +- docs/testing.md | 23 +- pytest.ini | 7 +- tests/driver/__init__.py | 1 - tests/e2e/__init__.py | 0 tests/e2e/graph/__init__.py | 0 .../graph}/data/longmemeval_data/README.md | 0 .../longmemeval_data/longmemeval_oracle.json | 0 tests/{evals => e2e/graph}/eval_cli.py | 2 +- .../graph/eval_graph_building.py} | 26 +- tests/e2e/graph/test_eval_harness.py | 9 + tests/{evals => e2e/graph}/utils.py | 0 tests/evals/pytest.ini | 4 - tests/helpers/__init__.py | 0 tests/helpers/embeddings.py | 61 +++ tests/helpers/services.py | 280 +++++++++++ tests/helpers_test.py | 350 -------------- tests/integration/__init__.py | 0 tests/integration/conftest.py | 6 + tests/integration/core/__init__.py | 0 tests/integration/core/shared/__init__.py | 0 .../core/shared/test_entity_exclusion.py} | 43 +- .../core/shared/test_graphium_bootstrap.py} | 2 +- .../core/shared/test_ingestion_pipeline.py} | 8 +- .../core/shared/test_repository_edges.py} | 78 +--- .../core/shared/test_repository_nodes.py} | 24 +- tests/integration/falkordb/__init__.py | 0 .../integration/falkordb/test_placeholder.py | 10 + tests/integration/kuzu/__init__.py | 0 tests/integration/kuzu/test_placeholder.py | 10 + .../llm_client/test_anthropic_client.py} | 0 tests/integration/neo4j/__init__.py | 0 tests/integration/neo4j/test_placeholder.py | 10 + tests/integration/shared/__init__.py | 0 tests/integration/shared/fixtures_services.py | 39 ++ tests/unit/__init__.py | 0 tests/unit/cross_encoder/__init__.py | 0 .../cross_encoder/test_bge_reranker_client.py | 0 .../test_gemini_reranker_client.py | 0 tests/unit/drivers/__init__.py | 0 .../drivers}/test_falkordb_driver.py | 0 tests/unit/embedder/__init__.py | 0 .../{ => unit}/embedder/embedder_fixtures.py | 0 .../embedder/test_embeddinggemma.py | 0 tests/{ => unit}/embedder/test_gemini.py | 5 +- tests/{ => unit}/embedder/test_openai.py | 3 +- tests/{ => unit}/embedder/test_voyage.py | 3 +- tests/unit/llm_client/__init__.py | 0 .../llm_client/test_anthropic_client.py | 0 tests/{ => unit}/llm_client/test_client.py | 0 tests/{ => unit}/llm_client/test_errors.py | 0 .../llm_client/test_gemini_client.py | 0 .../{ => unit}/llm_client/test_groq_client.py | 0 .../llm_client/test_litellm_client.py | 0 .../llm_client/test_pydantic_ai_adapter.py | 0 .../llm_client/test_structured_output.py | 0 tests/unit/mcp/__init__.py | 0 tests/{ => unit}/mcp/test_episode_queue.py | 0 tests/unit/orchestration/__init__.py | 0 tests/{ => unit}/orchestration/test_bulk.py | 0 .../orchestration/test_bulk_serialization.py | 0 .../test_episode_orchestrator.py | 0 .../test_initializer_factory.py} | 0 .../test_node_operations_sequence.py | 0 tests/unit/providers/__init__.py | 0 tests/{ => unit}/providers/test_factory.py | 0 tests/unit/search/__init__.py | 0 .../search/test_edge_search_orchestration.py | 0 tests/unit/search/test_lucene_utils.py | 17 + .../{ => unit}/search/test_search_filters.py | 0 .../{ => unit}/search/test_search_helpers.py | 0 .../search/test_search_utils_edges.py | 0 .../search/test_search_utils_filters.py | 0 tests/unit/utils/__init__.py | 0 .../utils/maintenance/test_bulk_utils.py | 0 .../utils/maintenance/test_edge_operations.py | 0 .../utils/maintenance/test_node_operations.py | 0 
.../maintenance/test_temporal_operations.py} | 0 .../utils/search/test_hybrid_search.py} | 0 tests/{ => unit/utils}/test_text_utils.py | 0 83 files changed, 990 insertions(+), 669 deletions(-) delete mode 100644 MCP_SMOKE_DEBUG.md create mode 100644 TEST_RESTRUCTURE_PLAN.md delete mode 100644 tests/driver/__init__.py create mode 100644 tests/e2e/__init__.py create mode 100644 tests/e2e/graph/__init__.py rename tests/{evals => e2e/graph}/data/longmemeval_data/README.md (100%) rename tests/{evals => e2e/graph}/data/longmemeval_data/longmemeval_oracle.json (100%) rename tests/{evals => e2e/graph}/eval_cli.py (94%) rename tests/{evals/eval_e2e_graph_building.py => e2e/graph/eval_graph_building.py} (88%) create mode 100644 tests/e2e/graph/test_eval_harness.py rename tests/{evals => e2e/graph}/utils.py (100%) delete mode 100644 tests/evals/pytest.ini create mode 100644 tests/helpers/__init__.py create mode 100644 tests/helpers/embeddings.py create mode 100644 tests/helpers/services.py delete mode 100644 tests/helpers_test.py create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/conftest.py create mode 100644 tests/integration/core/__init__.py create mode 100644 tests/integration/core/shared/__init__.py rename tests/{test_entity_exclusion_int.py => integration/core/shared/test_entity_exclusion.py} (93%) rename tests/{test_graphium_int.py => integration/core/shared/test_graphium_bootstrap.py} (97%) rename tests/{test_graphium_mock.py => integration/core/shared/test_ingestion_pipeline.py} (99%) rename tests/{test_edge_int.py => integration/core/shared/test_repository_edges.py} (88%) rename tests/{test_node_int.py => integration/core/shared/test_repository_nodes.py} (92%) create mode 100644 tests/integration/falkordb/__init__.py create mode 100644 tests/integration/falkordb/test_placeholder.py create mode 100644 tests/integration/kuzu/__init__.py create mode 100644 tests/integration/kuzu/test_placeholder.py rename tests/{llm_client/test_anthropic_client_int.py => integration/llm_client/test_anthropic_client.py} (100%) create mode 100644 tests/integration/neo4j/__init__.py create mode 100644 tests/integration/neo4j/test_placeholder.py create mode 100644 tests/integration/shared/__init__.py create mode 100644 tests/integration/shared/fixtures_services.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/cross_encoder/__init__.py rename tests/{ => unit}/cross_encoder/test_bge_reranker_client.py (100%) rename tests/{ => unit}/cross_encoder/test_gemini_reranker_client.py (100%) create mode 100644 tests/unit/drivers/__init__.py rename tests/{driver => unit/drivers}/test_falkordb_driver.py (100%) create mode 100644 tests/unit/embedder/__init__.py rename tests/{ => unit}/embedder/embedder_fixtures.py (100%) rename tests/{ => unit}/embedder/test_embeddinggemma.py (100%) rename tests/{ => unit}/embedder/test_gemini.py (99%) rename tests/{ => unit}/embedder/test_openai.py (98%) rename tests/{ => unit}/embedder/test_voyage.py (98%) create mode 100644 tests/unit/llm_client/__init__.py rename tests/{ => unit}/llm_client/test_anthropic_client.py (100%) rename tests/{ => unit}/llm_client/test_client.py (100%) rename tests/{ => unit}/llm_client/test_errors.py (100%) rename tests/{ => unit}/llm_client/test_gemini_client.py (100%) rename tests/{ => unit}/llm_client/test_groq_client.py (100%) rename tests/{ => unit}/llm_client/test_litellm_client.py (100%) rename tests/{ => unit}/llm_client/test_pydantic_ai_adapter.py (100%) rename tests/{ => 
unit}/llm_client/test_structured_output.py (100%) create mode 100644 tests/unit/mcp/__init__.py rename tests/{ => unit}/mcp/test_episode_queue.py (100%) create mode 100644 tests/unit/orchestration/__init__.py rename tests/{ => unit}/orchestration/test_bulk.py (100%) rename tests/{ => unit}/orchestration/test_bulk_serialization.py (100%) rename tests/{ => unit}/orchestration/test_episode_orchestrator.py (100%) rename tests/{test_graphium_factory_usage.py => unit/orchestration/test_initializer_factory.py} (100%) rename tests/{ => unit}/orchestration/test_node_operations_sequence.py (100%) create mode 100644 tests/unit/providers/__init__.py rename tests/{ => unit}/providers/test_factory.py (100%) create mode 100644 tests/unit/search/__init__.py rename tests/{ => unit}/search/test_edge_search_orchestration.py (100%) create mode 100644 tests/unit/search/test_lucene_utils.py rename tests/{ => unit}/search/test_search_filters.py (100%) rename tests/{ => unit}/search/test_search_helpers.py (100%) rename tests/{ => unit}/search/test_search_utils_edges.py (100%) rename tests/{ => unit}/search/test_search_utils_filters.py (100%) create mode 100644 tests/unit/utils/__init__.py rename tests/{ => unit}/utils/maintenance/test_bulk_utils.py (100%) rename tests/{ => unit}/utils/maintenance/test_edge_operations.py (100%) rename tests/{ => unit}/utils/maintenance/test_node_operations.py (100%) rename tests/{utils/maintenance/test_temporal_operations_int.py => unit/utils/maintenance/test_temporal_operations.py} (100%) rename tests/{utils/search/search_utils_test.py => unit/utils/search/test_hybrid_search.py} (100%) rename tests/{ => unit/utils}/test_text_utils.py (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 298cf96..a888e5f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,5 +41,5 @@ jobs: DISABLE_KUZU: '1' DISABLE_NEPTUNE: '1' run: | - uv run pytest tests/orchestration/test_bulk_serialization.py tests/search/test_search_utils_filters.py - uv run pytest tests/test_graphium_mock.py::test_add_episode_persists_nodes_and_edges + uv run pytest tests/unit/orchestration/test_bulk_serialization.py tests/unit/search/test_search_utils_filters.py + uv run pytest tests/integration/core/shared/test_ingestion_pipeline.py::test_add_episode_persists_nodes_and_edges diff --git a/MCP_SMOKE_DEBUG.md b/MCP_SMOKE_DEBUG.md deleted file mode 100644 index d9ffb49..0000000 --- a/MCP_SMOKE_DEBUG.md +++ /dev/null @@ -1,183 +0,0 @@ -# MCP Smoke Harness Debug Log - -## 2025-10-11 Update: Deterministic Pass - -Summary of the final, deterministic smoke run with `SEMAPHORE_LIMIT=10`: - -- Concurrency envelope and maintenance: - - Logged: `semaphore_gather: scheduling 24 tasks with concurrency=10` - - Logged: `semaphore_gather: completed 24 tasks` - - Logged: `Completed all maintenance queries` and `Indices and constraints ensured (~56–70ms)` - - Logged: `Graphium client initialized successfully` -- Health and MCP handshake: - - Health flipped to `status=ok` - - MCP `initialize` request returned HTTP 200 with JSON result - -Root causes and fixes: - -- Health loop bug (state import): modules imported state via `from .state import graphium_client` and similar. These are static bindings and did not reflect later `set_client(...)` updates, causing `/healthz` to report "Graphium client not initialized" even after initialization. - - Fix: switch to dynamic imports, `from . import state`, and reference `state.graphium_client`, `state.graphium_config`, etc. 
- - Files updated: `mcp_server/graphium_mcp/status.py`, `mcp_server/graphium_mcp/lifecycle.py`, `mcp_server/graphium_mcp/tools.py`, `mcp_server/graphium_mcp/queues.py`. -- Smoke script HTTP behavior: - - `/mcp` responds with a redirect; curl wasn’t following it and the URL lacked a trailing slash. - - Fix: use `MCP_URL=http://localhost:8000/mcp/` and `curl -L`. - - Also added jq/python fallbacks for JSON validation and removed early failure on the transient "not initialized" message. - -Instrumentation (enabled with `MCP_SMOKE_DEBUG=1`): -- Per-query IDs and timings; watchdog (20s default) with `SHOW TRANSACTIONS`/`SHOW INDEXES` snapshots. -- Concurrency envelope logs in `semaphore_gather`. -- Initialization timing for the maintenance block. -- Driver DEBUG logs for `neo4j`/`neo4j.bolt`. - -Run command used: -- `SMOKE_TAIL=1 SEMAPHORE_LIMIT=10 SMOKE_WATCHDOG_SECS=20 ./scripts/run_mcp_smoke.sh` - -Expected evidence in logs: -- `semaphore_gather: scheduling 24 tasks with concurrency=10` -- `semaphore_gather: completed 24 tasks` -- `Graphium client initialized successfully` -- `Health check reported status=ok` -- `MCP initialize request succeeded` - -Hardening suggestions: -- Consider `/healthz` returning HTTP 503 while not initialized (keeps current JSON body). -- Keep `GRAPHIUM_SERIALIZE_FULLTEXT=1` as a guarded fallback for environments with schema lock contention. - -Implemented: -- `/healthz` now returns HTTP 503 when not initialized and 200 when ready. Tests added in `mcp_server/tests/test_health_endpoint.py` to enforce this behavior and to prevent regressions in state handling and readiness semantics. - -## Execution Trace (from code inspection) -1. `graphium_mcp_server.py` parses config (`initialize_server`) then calls `run_server_lifecycle`. -2. `run_server_lifecycle` –> `initialize_graphium` ( `mcp_server/graphium_mcp/lifecycle.py:16-79` ). -3. `initialize_graphium` instantiates `Graphium`, then awaits `client.build_indices_and_constraints()`. -4. `Graphium.build_indices_and_constraints` delegates to `GraphMaintenanceOrchestrator.build_indices_and_constraints` –> `GraphMaintenanceService.build_indices_and_constraints` –> `graphium_core/orchestration/maintenance/graph_data_operations.build_indices_and_constraints`. -5. `build_indices_and_constraints` composes 24 Cypher statements (20 range indexes + 4 fulltext) and executes them via `semaphore_gather`. -6. `semaphore_gather` bounds concurrency with `asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)`. Runtime settings (`graphium_core/runtime.py`) pull `SEMAPHORE_LIMIT` from env. -7. Each `_execute_query` currently forwards to `driver.repositories.search.execute` → `Neo4jDriver.execute_query` → `AsyncGraphDatabase.execute_query` (returns an eager result immediately once the server acknowledges the statement). -8. When every coroutine returns, control unwinds back up to `initialize_graphium`, which records the client in `graphium_mcp.state` and health checks succeed (`collect_status` verifies connectivity). - -## Test vs Smoke Environment -- Integration tests (`tests/test_graphium_mock.py`, `tests/test_graphium_int.py`, etc.) drive the identical code path. They run against a Neo4j instance provisioned via `tests/helpers_test.py` with the default runtime semaphore limit (20). Because the limit is ≥ the number of range index statements, the gather call submits every query in one wave; each completes promptly, so initialization finishes and the tests pass. 
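For reference, the bounded-gather pattern described in the execution trace above reduces to the sketch below. This is an illustrative approximation, not the project implementation: the real `semaphore_gather` lives in `graphium_core/utils/async_utils.py` and also emits the debug envelope logs, and the default shown for `SEMAPHORE_LIMIT` is assumed for the example.

```python
import asyncio
import os
from collections.abc import Coroutine
from typing import Any

# Illustrative default; the real value is resolved in graphium_core/runtime.py.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', '20'))


async def semaphore_gather(
    *coroutines: Coroutine[Any, Any, Any],
    max_coroutines: int | None = None,
) -> list[Any]:
    """Run coroutines concurrently without exceeding the configured limit."""
    semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)

    async def bounded(coroutine: Coroutine[Any, Any, Any]) -> Any:
        async with semaphore:
            return await coroutine

    # Tasks beyond the limit wait here until earlier ones release the semaphore,
    # which is the wave behaviour seen with SEMAPHORE_LIMIT=10 and 24 statements.
    return list(await asyncio.gather(*(bounded(c) for c in coroutines)))
```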
-- The smoke harness overrides the semaphore limit to 10 (`mcp_server/docker-compose.testing.yml`, `SEMAPHORE_LIMIT=${SEMAPHORE_LIMIT:-10}`). That means only ten index statements execute at a time. We see log output for the first batch, but initialization never reaches “Graphium client initialized successfully,” and `/healthz` remains stuck in the “not initialized” state. No exception propagates, so the health endpoint cannot surface an error message. -- Manual `cypher-shell` runs confirm Neo4j accepts the “missing” statements and that `SHOW INDEXES` reports them as ONLINE, so the scripts are blocked in our orchestrator layer—not by Neo4j rejecting the statements. - -## Verified Facts (historical vs. current) -- Historical (pre-fix): we observed apparent stalls while health remained in "not initialized" and assumed `_execute_query` was hanging under concurrency. -- Current (post-instrumentation and state-import fix): all 24 maintenance queries complete deterministically with `SEMAPHORE_LIMIT=10`; initialization succeeds within ~60–70ms; health flips to `ok` and MCP initialize returns a valid JSON result. -- No Neo4j lock/wait was detected during the instrumented runs; watchdog did not trigger. - -## Open Questions / Next Steps -1. Decide whether to return HTTP 503 for not-initialized health to better align with container orchestrators. -2. Keep `GRAPHIUM_SERIALIZE_FULLTEXT=1` as a safe-toggle in unusual Neo4j deployments (serialize only fulltext index creation during boot). -3. Consider a session-per-query variant for schema ops behind a switch if environments with unusual driver/server behavior materialize. - -## Known Bad Fixes (documented for posterity) -- Calling `await result.consume()` or `result.summary()` on `EagerResult` is incorrect—the async driver returns immediate data and exposes `.summary` as a property; accessing it does not unblock the coroutine and, in our trials, the added calls triggered `AttributeError`. -- Redirecting to `driver.execute_query` instead of the repository wrapper didn’t change behavior; the wrapper already delegates to the driver’s `execute_query`. - -## Current Working Theory -The hang is caused by issuing more concurrent schema-altering statements than Neo4j (and/or the async driver) can process simultaneously. With `SEMAPHORE_LIMIT=10`, a subset of index creation tasks never complete, possibly because Neo4j serializes them internally while our driver waits for a result stream that never arrives. Tests avoid the issue because they run with the default limit (20) and effectively submit the full batch in one go, which Neo4j handles eagerly. We still need hard evidence of the stuck task state to confirm this theory and design a robust fix. - -## Confirmed Root Cause (Health endpoint loop) - -- After instrumentation, all 24 maintenance queries complete quickly and the client initializes successfully. -- Health kept reporting "Graphium client not initialized" due to a Python module-level binding bug: modules imported state variables with `from .state import graphium_client`, so later `set_client(...)` reassignments in the state module were not visible in those modules. -- Fix: refactor imports to `from . import state` and reference `state.graphium_client` (and other state attributes) dynamically. 
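A self-contained toy reproduces the stale-binding behaviour; the `state` stand-in below only imitates `mcp_server/graphium_mcp/state.py` and is not the project code:

```python
import sys
import types

# Stand-in for the real state module, registered so the imports below resolve.
state = types.ModuleType('state')
state.graphium_client = None
sys.modules['state'] = state


def set_client(client: object) -> None:
    state.graphium_client = client


# Buggy pattern: copies the *current* value (None) into this namespace; later
# set_client() calls rebind the module attribute, never this local name.
from state import graphium_client  # noqa: E402

# Fixed pattern: keep a module reference and read the attribute at call time.
import state as state_module  # noqa: E402

set_client(object())
print(graphium_client)               # still None -> health reports "not initialized"
print(state_module.graphium_client)  # the live client -> health can flip to ok
```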
- -Files updated to use dynamic state: -- `mcp_server/graphium_mcp/status.py:1` -- `mcp_server/graphium_mcp/lifecycle.py:1` -- `mcp_server/graphium_mcp/tools.py:1` -- `mcp_server/graphium_mcp/queues.py:1` - -With this fix, health returns `status=ok` after initialization. - -## Smoke Script Fixes - -- Follow HTTP redirect and require trailing slash: use `MCP_URL=http://localhost:8000/mcp/` and `curl -L`. -- Robust JSON parsing for the MCP response (jq/python3/python fallbacks) and stricter "has result" validation. -- Do not fail early on the transient "Graphium client not initialized" health message; keep polling until `ok` or timeout. -- Optional continuous container log streaming (`SMOKE_TAIL=1`). - -## Validated Outcome - -- After the above code and script changes, the smoke test consistently reports: - - `Health check reported status=ok` - - `MCP initialize request succeeded` - - `Smoke test finished successfully` - -## Instrumentation Added (collect hard data) - -- Query-level IDs and timings: every maintenance query now logs a unique `[qid]`, start, and completion with elapsed time. Slow/stuck queries are easier to correlate. -- Watchdog diagnostics: when `MCP_SMOKE_DEBUG=1`, a watchdog fires after `MCP_SMOKE_WATCHDOG_SECS` (default 30s) for any long-running maintenance query and logs: - - Top `SHOW TRANSACTIONS` entries (status, elapsed, query text) on Neo4j - - `SHOW INDEXES` counts by state (e.g., ONLINE vs total) - - A lightweight `asyncio` task count snapshot -- Concurrency envelope logging: `semaphore_gather` logs the batch size and the effective concurrency limit when `MCP_SMOKE_DEBUG=1`. -- Lifecycle timing: initialization logs how long index+constraint creation took in milliseconds. -- Driver logs: when `MCP_SMOKE_DEBUG=1`, the `neo4j` and `neo4j.bolt` loggers are set to DEBUG for wire-level context. - -Environment variables: -- `MCP_SMOKE_DEBUG=1` to enable all extra diagnostics -- `MCP_SMOKE_WATCHDOG_SECS=30` to change the watchdog trigger threshold -- `MCP_LOG_LEVEL=DEBUG|INFO|...` to override base log level -- `GRAPHIUM_SERIALIZE_FULLTEXT=1` to serialize fulltext index creation while keeping range indexes concurrent - -Relevant code references: -- `graphium_core/orchestration/maintenance/graph_data_operations.py:16` (watchdog and per-query logs) -- `graphium_core/utils/async_utils.py:1` (semaphore_gather diagnostics) -- `mcp_server/graphium_mcp/lifecycle.py:1` (initialization timing) -- `mcp_server/graphium_mcp_server.py:1` (log level + driver debug wiring) - -## Hypotheses And Tests (methodical plan) - -- Hypothesis A: Concurrency deadlock under lower `SEMAPHORE_LIMIT` causes some schema operations to stall. - - Test: Run with `SEMAPHORE_LIMIT=10` and `MCP_SMOKE_DEBUG=1`. Expect to see some queries without corresponding "Completed" logs. Watchdog should fire and report no active transactions in Neo4j; if so, the stall is client-side. - - Control: Run with `SEMAPHORE_LIMIT=24` then `SEMAPHORE_LIMIT=1`. Compare per-query timings and completion rates. - -- Hypothesis B: Certain schema statements (e.g., fulltext index creation) block longer than range indexes. - - Test: Group logs by `[qid]` and query text; confirm which categories trip the watchdog. If only fulltext queries are slow, serialize those while keeping range indexes concurrent. - -- Hypothesis C: Driver/session usage pattern under concurrent `execute_query` calls is exhausting a pool or awaiting a stream that never arrives. 
- - Test: Compare behavior when replacing the maintenance batch with a sequential loop (no semaphore) locally to confirm progress. If sequential runs complete quickly, concurrency is the trigger. - -## How To Reproduce And Capture Evidence - -Smoke (container, HTTP transport): - -- `cd mcp_server` -- `SEMAPHORE_LIMIT=10 MCP_SMOKE_DEBUG=1 MCP_SMOKE_WATCHDOG_SECS=20 docker compose -f docker-compose.testing.yml up --build` -- Observe logs. Expect to see: - - `semaphore_gather: scheduling 24 tasks with concurrency=10` - - `Executing maintenance query [qid]...` and matching `Completed...` lines - - If a query runs > 20s, look for `Watchdog: maintenance query [qid] still running...` followed by `SHOW TRANSACTIONS` and `SHOW INDEXES` summaries - -Direct run (no docker, for rapid iteration): - -- `cd mcp_server` -- `export NEO4J_URI=bolt://localhost:7687; export NEO4J_USER=neo4j; export NEO4J_PASSWORD=demodemo` -- `SEMAPHORE_LIMIT=10 MCP_SMOKE_DEBUG=1 uv run graphium_mcp_server.py --transport streamable-http` - -To verify Neo4j’s view manually: - -- `cypher-shell -a bolt://localhost:7687 -u "$NEO4J_USER" -p "$NEO4J_PASSWORD" "SHOW INDEXES"` -- `cypher-shell -a bolt://localhost:7687 -u "$NEO4J_USER" -p "$NEO4J_PASSWORD" "SHOW TRANSACTIONS"` - -## What To Look For In Logs - -- All 24 maintenance queries should emit paired start/complete lines with reasonable elapsed times. -- If a stall recurs: - - Note `[qid]` for the stuck query and its type (range vs fulltext) - - Check watchdog output: if `SHOW TRANSACTIONS` is empty/idle, it’s not a server-side lock - - Compare with `SHOW INDEXES` to see progression of index states (ONLINE vs POPULATING) - - Extract `neo4j.bolt` DEBUG snippets around the stuck `[qid]` - -## Contingency Fixes To Validate (toggleable) - -- Serialize fulltext index creation only (keep range indexes concurrent). If that removes stall, adopt targeted sequencing. -- Switch to a session-per-query `execute_write` pattern for schema ops to force transactional scoping and connection release after each statement. -- As a last resort for production reliability, run the maintenance block sequentially during boot (acceptable one-time cost) and keep a background job for periodic checks. - -## Decision Log (to be updated) - -- If Hypothesis A is confirmed and watchdog shows an idle server while client tasks hang, prioritize changing the submission pattern (sequential fulltext, or session-per-query) and keep concurrency for read-heavy paths. diff --git a/TEST_RESTRUCTURE_PLAN.md b/TEST_RESTRUCTURE_PLAN.md new file mode 100644 index 0000000..f4bd37b --- /dev/null +++ b/TEST_RESTRUCTURE_PLAN.md @@ -0,0 +1,442 @@ +# Graphium Test Restructure Plan + +## Goals + +1. **Separate unit vs. integration behaviour clearly.** Tiny unit tests should run without external services; integration suites should own database/driver dependencies and be opt‑in through markers or environment flags. +2. **Align test layout with module ownership.** Group tests by Graphium subsystem (orchestration, search, drivers, MCP, UI) so the intent is discoverable. +3. **Enable focused CI pipelines.** Keep a fast unit tier in the default workflow; provide reusable make/uv targets for the heavier integration checks. +4. **Eliminate ad‑hoc fixtures and environment leaks.** Centralise service fixtures (Neo4j, FalkorDB, Kùzu, Neptune) and make their lifecycle predictable. + +## Current Pain Points + +- `tests/test_graphium_mock.py` mixes mock‑heavy assertions with DB access and now fails in CI without manual `DISABLE_*` overrides. 
+- Service fixtures live in `tests/helpers_test.py`, which also registers drivers at import time; this makes selective execution fragile. +- Module grouping is inconsistent: some suites live under `tests/orchestration/`, others at the root (e.g., `test_edge_int.py`, `test_node_int.py`), and `tests/evals/` blends end‑to‑end runs with unit helpers. +- CI cannot distinguish low‑cost regression checks from integration jobs that need real GraphDB instances. + +## Proposed Directory Layout + +``` +tests/ +├─ unit/ +│ ├─ orchestration/ +│ │ ├─ test_bulk_serialization.py +│ │ ├─ test_ingestion_service.py +│ │ └─ … +│ ├─ search/ +│ ├─ mcp/ +│ ├─ utils/ +│ └─ … +├─ integration/ +│ ├─ neo4j/ +│ │ ├─ test_graphium_neo4j.py +│ │ └─ fixtures_neo4j.py +│ ├─ falkordb/ +│ ├─ kuzu/ +│ └─ shared/ +│ └─ fixtures_services.py +├─ e2e/ +│ └─ (long‑running graph build, eval harnesses) +└─ helpers/ + ├─ embeddings.py + ├─ factories.py + └─ markers.py +``` + +- **Unit tier** contains tests that can run with pure mocks/in-memory fixtures. +- **Integration tier** retains the existing service coverage but moves per‑provider suites behind pytest markers (`@pytest.mark.neo4j`, etc.). +- **E2E tier** is optional, only for the eval harness and smoke/regression scenarios that exercise CLI or the MCP server end-to-end. +- Shared helper modules move under `tests/helpers/` to avoid import side effects when the unit tier is collected. + +## Execution Strategy + +| Tier | Location | Marker / Env Flag | CI default | Command suggestion | +|-------------|-----------------------|-------------------|------------|------------------------------------------| +| Unit | `tests/unit` | `pytest -m "not integration and not e2e"` | ✅ | `uv run pytest tests/unit` | +| Integration | `tests/integration` | `@pytest.mark.integration` + provider markers | opt‑in | `uv run pytest -m "integration and neo4j"` | +| E2E | `tests/e2e` | `@pytest.mark.e2e` | manual | `uv run pytest -m e2e` | + +Markers to add in `pytest.ini`: + +``` +[pytest] +markers = + integration: tests that require external services + neo4j: requires a running Neo4j instance + falkordb: requires a running FalkorDB instance + kuzu: requires a running Kùzu instance + e2e: long-running end-to-end scenarios +``` + +## Detailed File Mapping + +### Integration Tests + +- **tests/integration/core/shared/test_community_operations.py** + - tests.test_graphium_mock::test_determine_entity_community + - tests.test_graphium_mock::test_get_community_clusters +- **tests/integration/core/shared/test_entity_exclusion.py** + - tests.test_entity_exclusion_int::test_exclude_all_types + - tests.test_entity_exclusion_int::test_exclude_default_entity_type + - tests.test_entity_exclusion_int::test_exclude_no_types + - tests.test_entity_exclusion_int::test_exclude_specific_custom_types + - tests.test_entity_exclusion_int::test_excluded_types_parameter_validation_in_add_episode + - tests.test_entity_exclusion_int::test_validation_invalid_excluded_types + - tests.test_entity_exclusion_int::test_validation_valid_excluded_types +- **tests/integration/core/shared/test_graphium_bootstrap.py** + - tests.test_graphium_int::test_graphium_init +- **tests/integration/core/shared/test_ingestion_pipeline.py** + - tests.test_graphium_mock::test_add_bulk + - tests.test_graphium_mock::test_add_episode_persists_nodes_and_edges + - tests.test_graphium_mock::test_filter_existing_duplicate_of_edges + - tests.test_graphium_mock::test_get_embeddings_for_communities + - tests.test_graphium_mock::test_get_embeddings_for_edges + - 
tests.test_graphium_mock::test_get_embeddings_for_nodes + - tests.test_graphium_mock::test_graphium_retrieve_episodes + - tests.test_graphium_mock::test_remove_episode +- **tests/integration/core/shared/test_repository_edges.py** + - tests.test_edge_int::test_community_edge + - tests.test_edge_int::test_entity_edge + - tests.test_edge_int::test_episodic_edge +- **tests/integration/core/shared/test_repository_nodes.py** + - tests.test_node_int::test_community_node + - tests.test_node_int::test_entity_node + - tests.test_node_int::test_episodic_node +- **tests/integration/core/shared/test_search_edges.py** + - tests.test_graphium_mock::test_edge_bfs_search + - tests.test_graphium_mock::test_edge_fulltext_search + - tests.test_graphium_mock::test_edge_similarity_search + - tests.test_graphium_mock::test_episode_mentions_reranker + - tests.test_graphium_mock::test_get_relevant_edges_and_invalidation_candidates +- **tests/integration/core/shared/test_search_nodes.py** + - tests.test_graphium_mock::test_community_fulltext_search + - tests.test_graphium_mock::test_community_similarity_search + - tests.test_graphium_mock::test_episode_fulltext_search + - tests.test_graphium_mock::test_get_communities_by_nodes + - tests.test_graphium_mock::test_get_mentioned_nodes + - tests.test_graphium_mock::test_get_relevant_nodes + - tests.test_graphium_mock::test_node_bfs_search + - tests.test_graphium_mock::test_node_distance_reranker + - tests.test_graphium_mock::test_node_fulltext_search + - tests.test_graphium_mock::test_node_similarity_search +- **tests/integration/cross_encoder/test_bge_reranker.py** + - tests.cross_encoder.test_bge_reranker_client::test_rank_basic_functionality + - tests.cross_encoder.test_bge_reranker_client::test_rank_empty_input + - tests.cross_encoder.test_bge_reranker_client::test_rank_single_passage +- **tests/integration/drivers/test_falkordb_driver.py** + - tests.driver.test_falkordb_driver::TestDatetimeConversion.test_convert_datetime_dict + - tests.driver.test_falkordb_driver::TestDatetimeConversion.test_convert_datetime_list_and_tuple + - tests.driver.test_falkordb_driver::TestDatetimeConversion.test_convert_other_types_unchanged + - tests.driver.test_falkordb_driver::TestDatetimeConversion.test_convert_single_datetime + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_close_calls_connection_close + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_delete_all_indexes + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_execute_query_converts_datetime_parameters + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_execute_query_handles_index_already_exists_error + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_execute_query_propagates_other_exceptions + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_execute_query_success + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_get_graph_with_name + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_get_graph_with_none_defaults_to_default_database + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_init_with_connection_params + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_init_with_falkor_db_instance + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_provider + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_session_creation + - tests.driver.test_falkordb_driver::TestFalkorDriver.test_session_creation_with_none_uses_default_database + - 
tests.driver.test_falkordb_driver::TestFalkorDriverIntegration.test_basic_integration_with_real_falkordb + - tests.driver.test_falkordb_driver::TestFalkorDriverSession.test_close_method + - tests.driver.test_falkordb_driver::TestFalkorDriverSession.test_execute_write_passes_session_and_args + - tests.driver.test_falkordb_driver::TestFalkorDriverSession.test_run_converts_datetime_objects_to_iso_strings + - tests.driver.test_falkordb_driver::TestFalkorDriverSession.test_run_propagates_exceptions +- **tests/integration/llm_client/test_anthropic_client.py** + - tests.llm_client.test_anthropic_client_int::test_extract_json_from_text + - tests.llm_client.test_anthropic_client_int::test_generate_simple_response + +### Unit Tests + +- **tests/unit/core/maintenance/test_bulk_utils.py** + - tests.utils.maintenance.test_bulk_utils::test_build_directed_uuid_map_chain + - tests.utils.maintenance.test_bulk_utils::test_build_directed_uuid_map_empty + - tests.utils.maintenance.test_bulk_utils::test_build_directed_uuid_map_preserves_direction + - tests.utils.maintenance.test_bulk_utils::test_candidate_edges_for_uses_semantic_similarity + - tests.utils.maintenance.test_bulk_utils::test_collect_edge_candidates_filters_by_endpoints + - tests.utils.maintenance.test_bulk_utils::test_dedupe_edges_bulk_deduplicates_within_episode + - tests.utils.maintenance.test_bulk_utils::test_dedupe_nodes_bulk_handles_empty_batch + - tests.utils.maintenance.test_bulk_utils::test_dedupe_nodes_bulk_missing_canonical_falls_back + - tests.utils.maintenance.test_bulk_utils::test_dedupe_nodes_bulk_reuses_canonical_nodes + - tests.utils.maintenance.test_bulk_utils::test_dedupe_nodes_bulk_single_episode + - tests.utils.maintenance.test_bulk_utils::test_dedupe_nodes_bulk_uuid_map_respects_direction + - tests.utils.maintenance.test_bulk_utils::test_find_exact_name_match_handles_case + - tests.utils.maintenance.test_bulk_utils::test_merge_canonical_nodes_detects_exact_match + - tests.utils.maintenance.test_bulk_utils::test_resolve_edge_pointers_updates_sources +- **tests/unit/core/maintenance/test_edge_operations.py** + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_invalidates_older_edges + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_inserts_new_fact_when_no_duplicates + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_invalidates_matching_edges + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_invalidates_partial_duplicates + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_removes_inconsistent_facts + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_updates_expired_edges + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_updates_fact + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_updates_validity_window + - tests.utils.maintenance.test_edge_operations::test_apply_invalidation_policy_updates_with_new_fact + - tests.utils.maintenance.test_edge_operations::test_convert_extracted_edges_to_entities_filters_blank_facts + - tests.utils.maintenance.test_edge_operations::test_convert_extracted_edges_to_entities_logs_invalid_indices + - tests.utils.maintenance.test_edge_operations::test_resolve_extracted_edge_accepts_unknown_fact_type + - tests.utils.maintenance.test_edge_operations::test_resolve_extracted_edge_exact_fact_short_circuit + - 
tests.utils.maintenance.test_edge_operations::test_resolve_extracted_edge_rejects_unmapped_fact_type + - tests.utils.maintenance.test_edge_operations::test_resolve_extracted_edge_uses_integer_indices_for_duplicates + - tests.utils.maintenance.test_edge_operations::test_resolve_extracted_edges_fast_path_deduplication + - tests.utils.maintenance.test_edge_operations::test_resolve_extracted_edges_keeps_unknown_names +- **tests/unit/core/maintenance/test_node_resolution.py** + - tests.utils.maintenance.test_node_operations::test_collect_candidate_nodes_dedupes_and_merges_override + - tests.utils.maintenance.test_node_operations::test_extract_attributes_from_nodes_with_callback + - tests.utils.maintenance.test_node_operations::test_extract_attributes_with_callback_generate_summary + - tests.utils.maintenance.test_node_operations::test_extract_attributes_with_callback_skip_summary + - tests.utils.maintenance.test_node_operations::test_extract_attributes_with_selective_callback + - tests.utils.maintenance.test_node_operations::test_extract_attributes_with_selective_callback_override_summary + - tests.utils.maintenance.test_node_operations::test_extract_attributes_without_callback_generates_summary + - tests.utils.maintenance.test_node_operations::test_has_high_entropy_rules + - tests.utils.maintenance.test_node_operations::test_hash_minhash_and_lsh + - tests.utils.maintenance.test_node_operations::test_jaccard_similarity_edges + - tests.utils.maintenance.test_node_operations::test_materialize_extracted_entities_respects_exclusions + - tests.utils.maintenance.test_node_operations::test_materialize_extracted_entities_sets_attribute_model + - tests.utils.maintenance.test_node_operations::test_name_entropy_variants + - tests.utils.maintenance.test_node_operations::test_normalize_helpers + - tests.utils.maintenance.test_node_operations::test_resolve_nodes_exact_match_skips_llm + - tests.utils.maintenance.test_node_operations::test_resolve_nodes_fuzzy_match + - tests.utils.maintenance.test_node_operations::test_resolve_nodes_low_entropy_uses_llm + - tests.utils.maintenance.test_node_operations::test_resolve_with_llm_ignores_duplicate_relative_ids + - tests.utils.maintenance.test_node_operations::test_resolve_with_llm_ignores_out_of_range_relative_ids + - tests.utils.maintenance.test_node_operations::test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted + - tests.utils.maintenance.test_node_operations::test_resolve_with_llm_updates_unresolved + - tests.utils.maintenance.test_node_operations::test_resolve_with_similarity_exact_match_updates_state + - tests.utils.maintenance.test_node_operations::test_resolve_with_similarity_low_entropy_defers_resolution + - tests.utils.maintenance.test_node_operations::test_resolve_with_similarity_multiple_exact_matches_defers_to_llm + - tests.utils.maintenance.test_node_operations::test_shingles_and_cache + - tests.utils.maintenance.test_node_operations::test_signature_dtype_guard +- **tests/unit/core/maintenance/test_temporal_operations.py** + - tests.utils.maintenance.test_temporal_operations_int::test_get_edge_contradictions + - tests.utils.maintenance.test_temporal_operations_int::test_get_edge_contradictions_multiple_existing + - tests.utils.maintenance.test_temporal_operations_int::test_get_edge_contradictions_no_contradictions + - tests.utils.maintenance.test_temporal_operations_int::test_get_edge_contradictions_no_effect + - tests.utils.maintenance.test_temporal_operations_int::test_get_edge_contradictions_temporal_update + - 
tests.utils.maintenance.test_temporal_operations_int::test_invalidate_edges_complex + - tests.utils.maintenance.test_temporal_operations_int::test_invalidate_edges_partial_update +- **tests/unit/core/orchestration/test_bulk_persistence.py** + - tests.orchestration.test_bulk::test_persist_bulk_payloads_wrapps_sequences_for_graph_operations + - tests.orchestration.test_bulk::test_serialize_entity_edges + - tests.orchestration.test_bulk::test_serialize_entity_nodes + - tests.orchestration.test_bulk::test_serialize_episodes_converts_source_enum +- **tests/unit/core/orchestration/test_bulk_serialization.py** + - tests.orchestration.test_bulk_serialization::test_serialize_episodic_edge_payload + - tests.orchestration.test_bulk_serialization::test_serialize_episodic_edge_payload_handles_missing_embedding + - tests.orchestration.test_bulk_serialization::test_serialize_entity_edge_payload + - tests.orchestration.test_bulk_serialization::test_serialize_entity_node_payload + - tests.orchestration.test_bulk_serialization::test_serialize_episode_payload +- **tests/unit/core/orchestration/test_episode_orchestrator.py** + - tests.orchestration.test_episode_orchestrator::test_merge_edge_type_map_accepts_sequence_signature + - tests.orchestration.test_episode_orchestrator::test_merge_edge_type_map_rejects_invalid_signatures +- **tests/unit/core/orchestration/test_initializer_factory.py** + - tests.test_graphium_factory_usage::test_graphium_invokes_reranker_factory +- **tests/unit/core/orchestration/test_node_operations_sequence.py** + - tests.orchestration.test_node_operations_sequence::test_collect_candidate_nodes_invocations + - tests.orchestration.test_node_operations_sequence::test_resolve_extracted_nodes_accepts_any_sequence +- **tests/unit/core/providers/test_factory.py** + - tests.providers.test_factory::test_create_embedder_from_settings + - tests.providers.test_factory::test_create_llm_client_from_settings + - tests.providers.test_factory::test_create_reranker_from_llm_settings +- **tests/unit/core/search/test_edge_search_orchestration.py** + - tests.search.test_edge_search_orchestration::test_edge_search_bfs_seeded_from_results + - tests.search.test_edge_search_orchestration::test_edge_search_cross_encoder + - tests.search.test_edge_search_orchestration::test_edge_search_rrF_only +- **tests/unit/core/search/test_hybrid_search.py** + - tests.utils.search.search_utils_test::test_hybrid_node_search_delegates_to_similarity_and_fulltext + - tests.utils.search.search_utils_test::test_hybrid_node_search_handles_missing_results + - tests.utils.search.search_utils_test::test_hybrid_node_search_merges_scores + - tests.utils.search.search_utils_test::test_hybrid_node_search_returns_nodes +- **tests/unit/core/search/test_lucene_utils.py** + - tests.helpers_test::test_lucene_sanitize +- **tests/unit/core/search/test_search_filters.py** + - tests.search.test_search_filters::test_build_date_filter_clause + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_builds_filters + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_handles_dates + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_handles_labels + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_handles_uuid_filters + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_returns_empty_lists + - tests.search.test_search_filters::test_node_search_filter_query_constructor_builds_filters +- 
**tests/unit/core/search/test_search_helpers.py** + - tests.search.test_search_helpers::test_build_search_config_handles_cross_encoder_weight + - tests.search.test_search_helpers::test_build_search_config_sets_defaults + - tests.search.test_search_helpers::test_build_search_config_validates_weights + - tests.search.test_search_helpers::test_rescore_with_cross_encoder_handles_empty + - tests.search.test_search_helpers::test_rescore_with_cross_encoder_sorts_results +- **tests/unit/core/search/test_search_utils_edges.py** + - tests.search.test_search_utils_edges::test_get_edge_invalidation_candidates_default_provider + - tests.search.test_search_utils_edges::test_get_relevant_edges_default_provider + - tests.search.test_search_utils_edges::test_node_distance_reranker +- **tests/unit/core/search/test_search_utils_filters.py** + - tests.search.test_search_utils_filters::test_build_edge_filter_clause_with_group_and_endpoints + - tests.search.test_search_utils_filters::test_build_edge_filter_clause_without_filters + - tests.search.test_search_utils_filters::test_collect_edge_matches_ignores_missing_uuid + - tests.search.test_search_utils_filters::test_fulltext_query_default_provider_includes_group_filter + - tests.search.test_search_utils_filters::test_fulltext_query_falkordb_delegates + - tests.search.test_search_utils_filters::test_fulltext_query_kuzu_respects_max_length +- **tests/unit/embedder/test_embeddinggemma.py** + - tests.embedder.test_embeddinggemma::test_embeddinggemma_create +- **tests/unit/embedder/test_gemini.py** + - tests.embedder.test_gemini::test_gemini_embedding_client_handles_rate_limits + - tests.embedder.test_gemini::test_gemini_embedding_client_initialization_defaults + - tests.embedder.test_gemini::test_gemini_embedding_client_parses_response +- **tests/unit/embedder/test_openai.py** + - tests.embedder.test_openai::test_openai_embedder_creates_embeddings + - tests.embedder.test_openai::test_openai_embedder_handles_rate_limit +- **tests/unit/embedder/test_voyage.py** + - tests.embedder.test_voyage::test_voyage_embedder_batches_inputs + - tests.embedder.test_voyage::test_voyage_embedder_handles_http_error +- **tests/unit/llm_client/test_anthropic_client.py** + - tests.llm_client.test_anthropic_client::TestAnthropicClientGenerateResponse.test_create_tool + - tests.llm_client.test_anthropic_client::TestAnthropicClientGenerateResponse.test_extract_json_from_text + - tests.llm_client.test_anthropic_client::TestAnthropicClientGenerateResponse.test_generate_response_with_text_response + - tests.llm_client.test_anthropic_client::TestAnthropicClientGenerateResponse.test_generate_response_with_tool_use + - tests.llm_client.test_anthropic_client::TestAnthropicClientGenerateResponse.test_rate_limit_error + - tests.llm_client.test_anthropic_client::TestAnthropicClientGenerateResponse.test_refusal_error + - tests.llm_client.test_anthropic_client::TestAnthropicClientGenerateResponse.test_validation_error_retry + - tests.llm_client.test_anthropic_client::TestAnthropicClientInitialization.test_init_with_config + - tests.llm_client.test_anthropic_client::TestAnthropicClientInitialization.test_init_with_custom_client + - tests.llm_client.test_anthropic_client::TestAnthropicClientInitialization.test_init_with_default_model + - tests.llm_client.test_anthropic_client::TestAnthropicClientInitialization.test_init_without_config +- **tests/unit/llm_client/test_client.py** + - tests.llm_client.test_client::test_client_calls_generate_response + - 
tests.llm_client.test_client::test_client_handles_structured_output + - tests.llm_client.test_client::test_client_raises_empty_response_error +- **tests/unit/llm_client/test_errors.py** + - tests.llm_client.test_errors::TestEmptyResponseError.test_message_assignment + - tests.llm_client.test_errors::TestEmptyResponseError.test_message_required + - tests.llm_client.test_errors::TestRateLimitError.test_custom_message + - tests.llm_client.test_errors::TestRateLimitError.test_default_message + - tests.llm_client.test_errors::TestRefusalError.test_message_assignment + - tests.llm_client.test_errors::TestRefusalError.test_message_required +- **tests/unit/llm_client/test_gemini_client.py** + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_custom_max_tokens + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_empty_response_handling + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_gemini_model_max_tokens_mapping + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_generate_response_simple_text + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_generate_response_with_structured_output + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_generate_response_with_system_message + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_get_model_for_size + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_max_retries_exceeded + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_max_tokens_precedence_fallback + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_model_size_selection + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_prompt_block_handling + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_quota_error_handling + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_rate_limit_error_handling + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_resource_exhausted_error_handling + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_retry_logic_with_safety_block + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_retry_logic_with_validation_error + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_safety_block_handling + - tests.llm_client.test_gemini_client::TestGeminiClientGenerateResponse.test_structured_output_parsing_error + - tests.llm_client.test_gemini_client::TestGeminiClientInitialization.test_init_with_config + - tests.llm_client.test_gemini_client::TestGeminiClientInitialization.test_init_with_default_model + - tests.llm_client.test_gemini_client::TestGeminiClientInitialization.test_init_with_thinking_config + - tests.llm_client.test_gemini_client::TestGeminiClientInitialization.test_init_without_config +- **tests/unit/llm_client/test_groq_client.py** + - tests.llm_client.test_groq_client::test_generate_response_returns_json + - tests.llm_client.test_groq_client::test_generate_response_salvages_json + - tests.llm_client.test_groq_client::test_generate_response_validates_model + - tests.llm_client.test_groq_client::test_rate_limit_error +- **tests/unit/llm_client/test_litellm_client.py** + - tests.llm_client.test_litellm_client::test_litellm_client_prefers_pydantic_ai + - 
tests.llm_client.test_litellm_client::test_litellm_client_raises_rate_limit_error + - tests.llm_client.test_litellm_client::test_litellm_client_reports_json_repair + - tests.llm_client.test_litellm_client::test_litellm_client_retries_on_json_error + - tests.llm_client.test_litellm_client::test_litellm_client_returns_json + - tests.llm_client.test_litellm_client::test_litellm_client_validates_response_model +- **tests/unit/llm_client/test_pydantic_ai_adapter.py** + - tests.llm_client.test_pydantic_ai_adapter::test_pydantic_ai_adapter_multi_turn + - tests.llm_client.test_pydantic_ai_adapter::test_pydantic_ai_adapter_requires_user_prompt +- **tests/unit/llm_client/test_structured_output.py** + - tests.llm_client.test_structured_output::test_format_structured_retry_message_json_error + - tests.llm_client.test_structured_output::test_format_structured_retry_message_validation_error + - tests.llm_client.test_structured_output::test_salvage_json_response_returns_none + - tests.llm_client.test_structured_output::test_salvage_json_response_truncated_object +- **tests/unit/mcp/test_episode_queue.py** + - tests.mcp.test_episode_queue::test_enqueue_episode_retries_and_records_failures +- **tests/unit/search/test_edge_search_orchestration.py** + - tests.search.test_edge_search_orchestration::test_edge_search_bfs_seeded_from_results + - tests.search.test_edge_search_orchestration::test_edge_search_cross_encoder + - tests.search.test_edge_search_orchestration::test_edge_search_rrF_only +- **tests/unit/search/test_hybrid_search.py** + - tests.utils.search.search_utils_test::test_hybrid_node_search_delegates_to_similarity_and_fulltext + - tests.utils.search.search_utils_test::test_hybrid_node_search_handles_missing_results + - tests.utils.search.search_utils_test::test_hybrid_node_search_merges_scores + - tests.utils.search.search_utils_test::test_hybrid_node_search_returns_nodes +- **tests/unit/search/test_lucene_utils.py** + - tests.helpers_test::test_lucene_sanitize +- **tests/unit/search/test_search_filters.py** + - tests.search.test_search_filters::test_build_date_filter_clause + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_builds_filters + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_handles_dates + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_handles_labels + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_handles_uuid_filters + - tests.search.test_search_filters::test_edge_search_filter_query_constructor_returns_empty_lists + - tests.search.test_search_filters::test_node_search_filter_query_constructor_builds_filters +- **tests/unit/search/test_search_helpers.py** + - tests.search.test_search_helpers::test_build_search_config_handles_cross_encoder_weight + - tests.search.test_search_helpers::test_build_search_config_sets_defaults + - tests.search.test_search_helpers::test_build_search_config_validates_weights + - tests.search.test_search_helpers::test_rescore_with_cross_encoder_handles_empty + - tests.search.test_search_helpers::test_rescore_with_cross_encoder_sorts_results +- **tests/unit/search/test_search_utils_edges.py** + - tests.search.test_search_utils_edges::test_get_edge_invalidation_candidates_default_provider + - tests.search.test_search_utils_edges::test_get_relevant_edges_default_provider + - tests.search.test_search_utils_edges::test_node_distance_reranker +- **tests/unit/search/test_search_utils_filters.py** + - 
tests.search.test_search_utils_filters::test_build_edge_filter_clause_with_group_and_endpoints + - tests.search.test_search_utils_filters::test_build_edge_filter_clause_without_filters + - tests.search.test_search_utils_filters::test_collect_edge_matches_ignores_missing_uuid + - tests.search.test_search_utils_filters::test_fulltext_query_default_provider_includes_group_filter + - tests.search.test_search_utils_filters::test_fulltext_query_falkordb_delegates + - tests.search.test_search_utils_filters::test_fulltext_query_kuzu_respects_max_length +- **tests/unit/utils/test_text_utils.py** + - tests.test_text_utils::test_max_summary_chars_constant + - tests.test_text_utils::test_truncate_at_sentence_empty + - tests.test_text_utils::test_truncate_at_sentence_exact_length + - tests.test_text_utils::test_truncate_at_sentence_multiple_periods + - tests.test_text_utils::test_truncate_at_sentence_no_boundary + - tests.test_text_utils::test_truncate_at_sentence_realistic_summary + - tests.test_text_utils::test_truncate_at_sentence_short_text + - tests.test_text_utils::test_truncate_at_sentence_strips_trailing_whitespace + - tests.test_text_utils::test_truncate_at_sentence_with_exclamation + - tests.test_text_utils::test_truncate_at_sentence_with_period + - tests.test_text_utils::test_truncate_at_sentence_with_question + +### End-to-End Tests + +- The current `tests/evals/` modules contain helper functions and CLI entry points but no pytest-collected tests. During restructure, convert these scripts into explicit `tests/e2e/graph/test_eval_cli.py` and `tests/e2e/graph/test_eval_graph_building.py` modules with `@pytest.mark.e2e` wrappers around existing logic, or leave them as manual harnesses documented outside pytest. + +## Migration Roadmap + +1. **Scaffold directories** (`tests/unit`, `tests/integration`, `tests/e2e`, `tests/helpers`). +2. **Extract helper utilities** from `tests/helpers_test.py` into modular fixtures: + - `helpers/factories.py` for fake embeddings/nodes. + - `helpers/services.py` for service setup/teardown, parameterised by provider. +3. **Split `test_graphium_mock.py`:** + - Unit cases into `tests/unit/orchestration/test_graphium_episode.py`. + - Service-backed cases (Neo4j/FalkorDB/Kùzu) into corresponding integration files. +4. **Move eval scripts** (`tests/evals/`) under `tests/e2e/` with explicit markers. +5. **Update CI workflow** to execute the unit suite only; provide reusable GitHub Action (reusable workflow or manual dispatch) for integration tiers with secrets for Neo4j credentials. +6. **Document local integration commands** in `docs/development.md` (e.g., `uv run pytest -m "integration and neo4j" --maxfail=1`). +7. **Introduce pre-commit hook or nox session** that mirrors CI unit tier for contributors. +8. **Enable per-provider integration pipelines** once fixtures stabilise (start with Neo4j, then FalkorDB/Kùzu). + +## Service Provisioning Notes + +- Neo4j: Provide docker-compose override in `docker-compose.test.yml` and add `make integration-neo4j` target. +- FalkorDB/Kùzu: Document optional usage; default to skipped unless environment variables (`FALKORDB_URI`, `KUZU_DB_PATH`) are supplied. +- Redis (for FalkorDB cluster detection): Stub FalkorDB client in unit tests; integration suite can spin up ephemeral Redis via docker compose. + +## Next Steps + +1. Review and sign off on the file-by-file mapping above (especially the large `test_graphium_mock.py` split). +2. Implement Phase 1 (directory scaffold, helper extraction, unit CI tweaks). +3. 
Convert ingestion/search suites per provider (Phase 2) and add markers. +4. Move eval harness into `tests/e2e` and document opt-in execution (Phase 3). +5. Update contributor documentation once restructuring lands. diff --git a/conftest.py b/conftest.py index e3a1267..8872ed2 100644 --- a/conftest.py +++ b/conftest.py @@ -12,9 +12,9 @@ # Without this file, you might encounter ModuleNotFoundError when trying to import modules from your project, especially when running tests. sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__)))) -from tests.helpers_test import graph_driver, mock_embedder +from tests.helpers.embeddings import make_mock_embedder -__all__ = ['graph_driver', 'mock_embedder'] +__all__ = ['mock_embedder'] _original_client_session = aiohttp.ClientSession _open_aiohttp_sessions: set[aiohttp.ClientSession] = set() @@ -39,6 +39,11 @@ async def _close_session(self, *close_args, **close_kwargs): # type: ignore[unu aiohttp.connector.BaseConnector._warn_unclosed = lambda self, *args, **kwargs: None # type: ignore[attr-defined, assignment] +@pytest.fixture +def mock_embedder(): + return make_mock_embedder() + + @pytest.fixture(scope='session', autouse=True) def _patch_aiohttp_client_session(): yield diff --git a/docs/testing.md b/docs/testing.md index 3a360e1..fa44e23 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -18,9 +18,9 @@ uv sync --extra dev ## Test Layout -- Unit tests live under `tests/` and do not require external services. -- Integration tests use a real graph backend (Neo4j/FalkorDB/Kuzu/Neptune) and - are marked with `@pytest.mark.integration`. +- Unit tests live under `tests/unit` and run entirely with mocks or in-memory fixtures. +- Integration tests live under `tests/integration`, hit real graph backends (Neo4j/FalkorDB/Kuzu/Neptune), and are marked with `@pytest.mark.integration` plus provider-specific markers. +- End-to-end harnesses sit under `tests/e2e` and are opt-in (`@pytest.mark.e2e`), primarily covering long-running evaluation flows. 
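As a sketch of how the tiering shows up in test code, a provider-backed test would carry both markers so the commands below can select it. The module path, fixture name, and body here are illustrative, not files from this change:

```python
# tests/integration/neo4j/test_example.py (hypothetical module for illustration)
import os

import pytest

# Both markers let `-m "integration and neo4j"` select exactly this tier.
pytestmark = [pytest.mark.integration, pytest.mark.neo4j]


@pytest.mark.skipif(os.getenv('DISABLE_NEO4J') == '1', reason='Neo4j disabled for this run')
async def test_driver_connectivity(graph_driver):  # fixture name assumed, not part of this diff
    # A real test would exercise repository or ingestion code against Neo4j;
    # the assertion only marks where that logic would live.
    assert graph_driver is not None
```

Because the markers are registered in `pytest.ini` and `asyncio_mode = auto` is set, the async test needs no extra decorator and is excluded from the default `-m "not integration and not e2e"` run.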
## Quick Start (Neo4j only) @@ -50,14 +50,19 @@ DISABLE_NEPTUNE=1 3) Run a specific integration test: ```bash -uv run pytest -q tests/test_edge_int.py::test_community_edge +uv run pytest -q tests/integration/core/shared/test_repository_edges.py::test_community_edge ``` -Run all non‑integration tests (fast path): +Run all non-integration tests (fast path): ```bash -DISABLE_NEO4J=1 DISABLE_FALKORDB=1 DISABLE_KUZU=1 DISABLE_NEPTUNE=1 \ - uv run pytest -m "not integration" +uv run pytest -m "not integration and not e2e" +``` + +Run the Neo4j-backed integration tier: + +```bash +uv run pytest -m "integration and neo4j" ``` ## Environment Resolution @@ -87,8 +92,8 @@ DISABLE_NEO4J=1 DISABLE_FALKORDB=1 DISABLE_KUZU=1 DISABLE_NEPTUNE=1 \ ## Useful Commands - List tests: `uv run pytest -q` -- Single file: `uv run pytest tests/test_edge_int.py -q` -- Single test: `uv run pytest tests/test_edge_int.py::test_entity_edge -q` +- Single file: `uv run pytest tests/integration/core/shared/test_repository_edges.py -q` +- Single test: `uv run pytest tests/integration/core/shared/test_repository_edges.py::test_entity_edge -q` - With logs: `uv run pytest -q -o log_cli=true -o log_cli_level=INFO` - Coverage + reports: `uv run pytest` (writes HTML to `docs/reports/coverage/`) - Complexity snapshot: `./scripts/generate_complexity_report.sh` (updates `docs/reports/code-complexity.md`) diff --git a/pytest.ini b/pytest.ini index 0170e13..4c37c81 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,11 @@ [pytest] markers = - integration: marks tests as integration tests + integration: tests that require external services + neo4j: requires a running Neo4j instance + falkordb: requires a running FalkorDB instance + kuzu: requires a running Kuzu instance + neptune: requires a running Neptune instance + e2e: long-running end-to-end scenarios asyncio_default_fixture_loop_scope = function asyncio_mode = auto addopts = diff --git a/tests/driver/__init__.py b/tests/driver/__init__.py deleted file mode 100644 index b5d9d58..0000000 --- a/tests/driver/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for database drivers.""" diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/e2e/graph/__init__.py b/tests/e2e/graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/evals/data/longmemeval_data/README.md b/tests/e2e/graph/data/longmemeval_data/README.md similarity index 100% rename from tests/evals/data/longmemeval_data/README.md rename to tests/e2e/graph/data/longmemeval_data/README.md diff --git a/tests/evals/data/longmemeval_data/longmemeval_oracle.json b/tests/e2e/graph/data/longmemeval_data/longmemeval_oracle.json similarity index 100% rename from tests/evals/data/longmemeval_data/longmemeval_oracle.json rename to tests/e2e/graph/data/longmemeval_data/longmemeval_oracle.json diff --git a/tests/evals/eval_cli.py b/tests/e2e/graph/eval_cli.py similarity index 94% rename from tests/evals/eval_cli.py rename to tests/e2e/graph/eval_cli.py index 39dd615..a0b3a39 100644 --- a/tests/evals/eval_cli.py +++ b/tests/e2e/graph/eval_cli.py @@ -1,7 +1,7 @@ import argparse import asyncio -from tests.evals.eval_e2e_graph_building import build_baseline_graph, eval_graph +from tests.e2e.graph.eval_graph_building import build_baseline_graph, eval_graph async def main(): diff --git a/tests/evals/eval_e2e_graph_building.py b/tests/e2e/graph/eval_graph_building.py similarity index 88% rename from tests/evals/eval_e2e_graph_building.py rename to 
tests/e2e/graph/eval_graph_building.py index 9be0a28..dee00cf 100644 --- a/tests/evals/eval_e2e_graph_building.py +++ b/tests/e2e/graph/eval_graph_building.py @@ -3,6 +3,7 @@ import json from datetime import UTC, datetime +from pathlib import Path import pandas as pd @@ -12,8 +13,15 @@ from graphium_core.nodes import EpisodeType from graphium_core.prompts import prompt_library from graphium_core.prompts.eval import EvalAddEpisodeResults +from graphium_core.settings import Neo4jSettings from graphium_core.utils.async_utils import semaphore_gather -from tests.test_graphium_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER + +BASE_DIR = Path(__file__).parent +DATA_DIR = BASE_DIR / 'data' +_neo4j_settings = Neo4jSettings() +NEO4J_URI = _neo4j_settings.uri +NEO4J_USER = _neo4j_settings.user +NEO4J_PASSWORD = _neo4j_settings.password async def build_subgraph( @@ -61,10 +69,8 @@ async def build_graph( group_id_suffix: str, multi_session_count: int, session_length: int, graphium: Graphium ) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]: # Get longmemeval dataset - lme_dataset_option = ( - 'data/longmemeval_data/longmemeval_oracle.json' # Can be _oracle, _s, or _m - ) - lme_dataset_df = pd.read_json(lme_dataset_option) + lme_dataset_path = DATA_DIR / 'longmemeval_data' / 'longmemeval_oracle.json' + lme_dataset_df = pd.read_json(lme_dataset_path) add_episode_results: dict[str, list[AddEpisodeResults]] = {} add_episode_context: dict[str, list[str]] = {} @@ -92,13 +98,13 @@ async def build_graph( async def build_baseline_graph(multi_session_count: int, session_length: int): # Use gpt-4.1-mini for graph building baseline llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini')) - graphium = Graphium(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client) + graphium = Graphium(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD, llm_client=llm_client) add_episode_results, _ = await build_graph( 'baseline', multi_session_count, session_length, graphium ) - filename = 'baseline_graph_results.json' + filename = BASE_DIR / 'baseline_graph_results.json' serializable_baseline_graph_results = { key: [item.model_dump(mode='json') for item in value] @@ -112,8 +118,8 @@ async def build_baseline_graph(multi_session_count: int, session_length: int): async def eval_graph(multi_session_count: int, session_length: int, llm_client=None) -> float: if llm_client is None: llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini')) - graphium = Graphium(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client) - with open('baseline_graph_results.json') as file: + graphium = Graphium(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD, llm_client=llm_client) + with open(BASE_DIR / 'baseline_graph_results.json') as file: baseline_results_raw = json.load(file) baseline_results: dict[str, list[AddEpisodeResults]] = { @@ -124,7 +130,7 @@ async def eval_graph(multi_session_count: int, session_length: int, llm_client=N 'candidate', multi_session_count, session_length, graphium ) - filename = 'candidate_graph_results.json' + filename = BASE_DIR / 'candidate_graph_results.json' candidate_baseline_graph_results = { key: [item.model_dump(mode='json') for item in value] diff --git a/tests/e2e/graph/test_eval_harness.py b/tests/e2e/graph/test_eval_harness.py new file mode 100644 index 0000000..c8bd08e --- /dev/null +++ b/tests/e2e/graph/test_eval_harness.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. 
+ +import pytest + + +@pytest.mark.e2e +def test_eval_harness_placeholder(): + pytest.skip('Eval harness requires manual orchestration and external services.') diff --git a/tests/evals/utils.py b/tests/e2e/graph/utils.py similarity index 100% rename from tests/evals/utils.py rename to tests/e2e/graph/utils.py diff --git a/tests/evals/pytest.ini b/tests/evals/pytest.ini deleted file mode 100644 index 37735bf..0000000 --- a/tests/evals/pytest.ini +++ /dev/null @@ -1,4 +0,0 @@ -[pytest] -asyncio_default_fixture_loop_scope = function -markers = - integration: marks tests as integration tests \ No newline at end of file diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/helpers/embeddings.py b/tests/helpers/embeddings.py new file mode 100644 index 0000000..8a23a12 --- /dev/null +++ b/tests/helpers/embeddings.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. + +"""Embedding helpers reused across integration and unit suites.""" + +from __future__ import annotations + +from typing import Iterable +from unittest.mock import Mock + +import numpy as np + +from graphium_core.embedder.client import EmbedderClient + +EMBEDDING_DIM = 384 + + +def default_embeddings() -> dict[str, list[float]]: + keys = [ + 'Alice', + 'Bob', + 'Alice likes Bob', + 'test_entity_1', + 'test_entity_2', + 'test_entity_3', + 'test_entity_4', + 'test_entity_alice', + 'test_entity_bob', + 'test_entity_1 is a duplicate of test_entity_2', + 'test_entity_3 is a duplicate of test_entity_4', + 'test_entity_1 relates to test_entity_2', + 'test_entity_1 relates to test_entity_3', + 'test_entity_2 relates to test_entity_3', + 'test_entity_1 relates to test_entity_4', + 'test_entity_2 relates to test_entity_4', + 'test_entity_3 relates to test_entity_4', + 'test_entity_2 relates to test_entity_4', + 'test_entity_2 relates to test_entity_3', + 'test_community_1', + 'test_community_2', + ] + embeddings = { + key: np.random.uniform(0.0, 0.9, EMBEDDING_DIM).tolist() for key in keys + } + embeddings['Alice Smith'] = embeddings['Alice'] + return embeddings + + +def make_mock_embedder(embedding_lookup: dict[str, list[float]] | None = None) -> Mock: + """Build a mock embedder compatible with Graphium helpers.""" + store = embedding_lookup or default_embeddings() + mock_model = Mock(spec=EmbedderClient) + + def mock_embed(input_data: str | Iterable[str]) -> list[float]: + if isinstance(input_data, str): + return store[input_data] + combined_input = ' '.join(input_data) + return store[combined_input] + + mock_model.create.side_effect = mock_embed + return mock_model diff --git a/tests/helpers/services.py b/tests/helpers/services.py new file mode 100644 index 0000000..b19ef56 --- /dev/null +++ b/tests/helpers/services.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. + +"""Shared helpers for Graphium integration tests. + +These utilities centralise driver provisioning and common repository helpers so +the individual test modules no longer need to import heavy fixtures directly. +Nothing in this module triggers driver imports at import time; providers are +resolved lazily to keep unit collections lightweight. 
+""" + +from __future__ import annotations + +import os +from collections.abc import Iterable +from typing import Callable + +import numpy as np + +from dotenv import load_dotenv + +from graphium_core.driver.driver import GraphDriver, GraphProvider +from graphium_core.edges import CommunityEdge, EntityEdge, EpisodicEdge +from graphium_core.nodes import CommunityNode, EntityNode, EpisodicNode +from graphium_core.orchestration.maintenance.graph_data_operations import clear_data +from graphium_core.settings import ( + FalkorSettings, + KuzuSettings, + Neo4jSettings, + NeptuneSettings, +) + +load_dotenv() + +os.environ.setdefault('DISABLE_NEPTUNE', 'True') + +GROUP_ID = 'graphium_test_group' +GROUP_ID_ALT = 'graphium_test_group_2' + + +def _provider_enabled(flag: str) -> bool: + return os.getenv(flag) is None + + +def _import_if_available( + provider: GraphProvider, + import_factory: Callable[[], type[GraphDriver]], +) -> type[GraphDriver] | None: + try: + return import_factory() + except ImportError as exc: + msg = f'Provider {provider.value} is enabled but unavailable: {exc}' + raise ImportError(msg) from exc + + +def available_providers() -> list[GraphProvider]: + """Return the graph providers that can be exercised for integration tests.""" + providers: list[GraphProvider] = [] + if _provider_enabled('DISABLE_NEO4J'): + if _import_if_available(GraphProvider.NEO4J, _resolve_neo4j_driver): + providers.append(GraphProvider.NEO4J) + if _provider_enabled('DISABLE_FALKORDB'): + if _import_if_available(GraphProvider.FALKORDB, _resolve_falkordb_driver): + providers.append(GraphProvider.FALKORDB) + if _provider_enabled('DISABLE_KUZU'): + if _import_if_available(GraphProvider.KUZU, _resolve_kuzu_driver): + providers.append(GraphProvider.KUZU) + if _provider_enabled('DISABLE_NEPTUNE'): + driver = _import_if_available(GraphProvider.NEPTUNE, _resolve_neptune_driver) + if driver: + providers.append(GraphProvider.NEPTUNE) + return providers + + +def make_driver(provider: GraphProvider) -> GraphDriver: + """Materialise a driver instance for the given provider.""" + if provider == GraphProvider.NEO4J: + settings = Neo4jSettings() + driver_cls = _resolve_neo4j_driver() + return driver_cls( + uri=settings.uri, + user=settings.user, + password=settings.password, + ) + if provider == GraphProvider.FALKORDB: + settings = FalkorSettings() + driver_cls = _resolve_falkordb_driver() + return driver_cls( + host=settings.host, + port=int(settings.port), + username=settings.username, + password=settings.password, + ) + if provider == GraphProvider.KUZU: + settings = KuzuSettings() + driver_cls = _resolve_kuzu_driver() + return driver_cls(db=settings.db) + if provider == GraphProvider.NEPTUNE: + settings = NeptuneSettings() + driver_cls = _resolve_neptune_driver() + return driver_cls( + host=settings.host or 'localhost', + port=int(settings.port), + aoss_host=settings.aoss_host, + ) + raise ValueError(f'Unsupported provider: {provider!r}') + + +async def prepare_driver(provider: GraphProvider) -> GraphDriver: + """Provision a driver and clear shared test groups.""" + driver = make_driver(provider) + await clear_data(driver, [GROUP_ID, GROUP_ID_ALT]) + return driver + + +async def teardown_driver(driver: GraphDriver) -> None: + await driver.close() + + +async def save_nodes(graph_driver: GraphDriver, *nodes: EntityNode) -> None: + await graph_driver.repositories.save_nodes(list(nodes)) + + +async def save_edges(graph_driver: GraphDriver, *edges: EntityEdge) -> None: + await 
graph_driver.repositories.save_edges(list(edges)) + + +async def delete_nodes(graph_driver: GraphDriver, *nodes: EntityNode) -> None: + for node in nodes: + await graph_driver.repositories.delete_node(node) + + +async def delete_edges(graph_driver: GraphDriver, *edges: EntityEdge) -> None: + for edge in edges: + await graph_driver.repositories.delete_edge(edge) + + +async def save_all(graph_driver: GraphDriver, *items: object) -> None: + repo = graph_driver.repositories + for item in items: + if isinstance(item, (EntityNode, EpisodicNode, CommunityNode)): + await repo.save_node(item) + elif isinstance(item, (EntityEdge, EpisodicEdge, CommunityEdge)): + await repo.save_edge(item) + else: + raise TypeError(f'Unsupported item type for save_all: {type(item)!r}') + + +async def delete_all(graph_driver: GraphDriver, *items: object) -> None: + repo = graph_driver.repositories + for item in items: + if isinstance(item, (EntityNode, EpisodicNode, CommunityNode)): + await repo.delete_node(item) + elif isinstance(item, (EntityEdge, EpisodicEdge, CommunityEdge)): + await repo.delete_edge(item) + else: + raise TypeError(f'Unsupported item type for delete_all: {type(item)!r}') + + +async def get_node_count(driver: GraphDriver, uuids: Iterable[str]) -> int: + results, _, _ = await driver.execute_query( + """ + MATCH (n) + WHERE n.uuid IN $uuids + RETURN COUNT(n) as count + """, + uuids=list(uuids), + ) + return int(results[0]['count']) + + +async def get_edge_count(driver: GraphDriver, uuids: Iterable[str]) -> int: + results, _, _ = await driver.execute_query( + """ + MATCH (n)-[e]->(m) + WHERE e.uuid IN $uuids + RETURN COUNT(e) as count + UNION ALL + MATCH (e:RelatesToNode_) + WHERE e.uuid IN $uuids + RETURN COUNT(e) as count + """, + uuids=list(uuids), + ) + return sum(int(result['count']) for result in results) + + +async def assert_entity_node_equals( + graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode +) -> None: + await graph_driver.repositories.entity_nodes.load_name_embedding(retrieved) + assert retrieved.uuid == sample.uuid + assert retrieved.name == sample.name + assert retrieved.group_id == sample.group_id + assert set(retrieved.labels) == set(sample.labels) + assert retrieved.created_at == sample.created_at + assert retrieved.name_embedding is not None + assert sample.name_embedding is not None + assert np.allclose(retrieved.name_embedding, sample.name_embedding) + assert retrieved.summary == sample.summary + assert retrieved.attributes_dict() == sample.attributes_dict() + + +async def assert_community_node_equals( + graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode +) -> None: + await graph_driver.repositories.community_nodes.load_name_embedding(retrieved) + assert retrieved.uuid == sample.uuid + assert retrieved.name == sample.name + assert retrieved.group_id == GROUP_ID + assert retrieved.created_at == sample.created_at + assert retrieved.name_embedding is not None + assert sample.name_embedding is not None + assert np.allclose(retrieved.name_embedding, sample.name_embedding) + assert retrieved.summary == sample.summary + + +async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode) -> None: + assert retrieved.uuid == sample.uuid + assert retrieved.name == sample.name + assert retrieved.group_id == GROUP_ID + assert retrieved.created_at == sample.created_at + assert retrieved.source == sample.source + assert retrieved.source_description == sample.source_description + assert retrieved.content == sample.content + assert 
retrieved.valid_at == sample.valid_at + assert set(retrieved.entity_edges) == set(sample.entity_edges) + + +async def assert_entity_edge_equals( + graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge +) -> None: + await graph_driver.repositories.entity_edges.load_fact_embedding(retrieved) + assert retrieved.uuid == sample.uuid + assert retrieved.group_id == sample.group_id + assert retrieved.created_at == sample.created_at + assert retrieved.source_node_uuid == sample.source_node_uuid + assert retrieved.target_node_uuid == sample.target_node_uuid + assert retrieved.name == sample.name + assert retrieved.fact == sample.fact + assert retrieved.fact_embedding is not None + assert sample.fact_embedding is not None + assert np.allclose(retrieved.fact_embedding, sample.fact_embedding) + assert retrieved.episodes == sample.episodes + assert retrieved.expired_at == sample.expired_at + assert retrieved.valid_at == sample.valid_at + assert retrieved.invalid_at == sample.invalid_at + assert retrieved.attributes_dict() == sample.attributes_dict() + + +async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge) -> None: + assert retrieved.uuid == sample.uuid + assert retrieved.group_id == sample.group_id + assert retrieved.created_at == sample.created_at + assert retrieved.source_node_uuid == sample.source_node_uuid + assert retrieved.target_node_uuid == sample.target_node_uuid + + +def _resolve_neo4j_driver() -> type[GraphDriver]: + from graphium_core.driver.neo4j_driver import Neo4jDriver + + return Neo4jDriver + + +def _resolve_falkordb_driver() -> type[GraphDriver]: + from graphium_core.driver.falkordb_driver import FalkorDriver + + return FalkorDriver + + +def _resolve_kuzu_driver() -> type[GraphDriver]: + from graphium_core.driver.kuzu_driver import KuzuDriver + + return KuzuDriver + + +def _resolve_neptune_driver() -> type[GraphDriver]: + from graphium_core.driver.neptune_driver import NeptuneDriver + + return NeptuneDriver diff --git a/tests/helpers_test.py b/tests/helpers_test.py deleted file mode 100644 index 1380f95..0000000 --- a/tests/helpers_test.py +++ /dev/null @@ -1,350 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Modified by the Graphium project. 
- -import os -from unittest.mock import Mock - -import numpy as np -import pytest -from dotenv import load_dotenv - -from graphium_core.driver.driver import GraphDriver, GraphProvider -from graphium_core.edges import CommunityEdge, EntityEdge, EpisodicEdge -from graphium_core.embedder.client import EmbedderClient -from graphium_core.nodes import CommunityNode, EntityNode, EpisodicNode -from graphium_core.orchestration.maintenance.graph_data_operations import clear_data -from graphium_core.search.lucene import sanitize as lucene_sanitize -from graphium_core.settings import ( - FalkorSettings, - KuzuSettings, - Neo4jSettings, - NeptuneSettings, -) - -load_dotenv() - -drivers: list[GraphProvider] = [] -if os.getenv('DISABLE_NEO4J') is None: - try: - from graphium_core.driver.neo4j_driver import Neo4jDriver - - drivers.append(GraphProvider.NEO4J) - except ImportError: - raise - -if os.getenv('DISABLE_FALKORDB') is None: - try: - from graphium_core.driver.falkordb_driver import FalkorDriver - - drivers.append(GraphProvider.FALKORDB) - except ImportError: - raise - -if os.getenv('DISABLE_KUZU') is None: - try: - from graphium_core.driver.kuzu_driver import KuzuDriver - - drivers.append(GraphProvider.KUZU) - except ImportError: - raise - -# Disable Neptune for now -os.environ['DISABLE_NEPTUNE'] = 'True' -if os.getenv('DISABLE_NEPTUNE') is None: - try: - from graphium_core.driver.neptune_driver import NeptuneDriver - - drivers.append(GraphProvider.NEPTUNE) - except ImportError: - raise - -_neo4j_settings = Neo4jSettings() -NEO4J_URI = _neo4j_settings.uri -NEO4J_USER = _neo4j_settings.user -NEO4J_PASSWORD = _neo4j_settings.password - -_falkor_settings = FalkorSettings() -FALKORDB_HOST = _falkor_settings.host -FALKORDB_PORT = str(_falkor_settings.port) -FALKORDB_USER = _falkor_settings.username -FALKORDB_PASSWORD = _falkor_settings.password - -_neptune_settings = NeptuneSettings() -NEPTUNE_HOST = _neptune_settings.host or 'localhost' -NEPTUNE_PORT = _neptune_settings.port -AOSS_HOST = _neptune_settings.aoss_host - -_kuzu_settings = KuzuSettings() -KUZU_DB = _kuzu_settings.db - -group_id = 'graphium_test_group' -group_id_2 = 'graphium_test_group_2' - - -def get_driver(provider: GraphProvider) -> GraphDriver: - if provider == GraphProvider.NEO4J: - return Neo4jDriver( - uri=NEO4J_URI, - user=NEO4J_USER, - password=NEO4J_PASSWORD, - ) - elif provider == GraphProvider.FALKORDB: - return FalkorDriver( - host=FALKORDB_HOST, - port=int(FALKORDB_PORT), - username=FALKORDB_USER, - password=FALKORDB_PASSWORD, - ) - elif provider == GraphProvider.KUZU: - driver = KuzuDriver( - db=KUZU_DB, - ) - return driver - elif provider == GraphProvider.NEPTUNE: - return NeptuneDriver( - host=NEPTUNE_HOST, - port=int(NEPTUNE_PORT), - aoss_host=AOSS_HOST, - ) - else: - raise ValueError(f'Driver {provider} not available') - - -@pytest.fixture(params=drivers) -async def graph_driver(request): - driver = request.param - graph_driver = get_driver(driver) - await clear_data(graph_driver, [group_id, group_id_2]) - try: - yield graph_driver # provide driver to the test - finally: - # always called, even if the test fails or raises - # await clean_up(graph_driver) - await graph_driver.close() - - -embedding_dim = 384 -embeddings = { - key: np.random.uniform(0.0, 0.9, embedding_dim).tolist() - for key in [ - 'Alice', - 'Bob', - 'Alice likes Bob', - 'test_entity_1', - 'test_entity_2', - 'test_entity_3', - 'test_entity_4', - 'test_entity_alice', - 'test_entity_bob', - 'test_entity_1 is a duplicate of test_entity_2', - 'test_entity_3 
is a duplicate of test_entity_4', - 'test_entity_1 relates to test_entity_2', - 'test_entity_1 relates to test_entity_3', - 'test_entity_2 relates to test_entity_3', - 'test_entity_1 relates to test_entity_4', - 'test_entity_2 relates to test_entity_4', - 'test_entity_3 relates to test_entity_4', - 'test_entity_1 relates to test_entity_2', - 'test_entity_3 relates to test_entity_4', - 'test_entity_2 relates to test_entity_3', - 'test_community_1', - 'test_community_2', - ] -} -embeddings['Alice Smith'] = embeddings['Alice'] - - -@pytest.fixture -def mock_embedder(): - mock_model = Mock(spec=EmbedderClient) - - def mock_embed(input_data): - if isinstance(input_data, str): - return embeddings[input_data] - elif isinstance(input_data, list): - combined_input = ' '.join(input_data) - return embeddings[combined_input] - else: - raise ValueError(f'Unsupported input type: {type(input_data)}') - - mock_model.create.side_effect = mock_embed - return mock_model - - -def test_lucene_sanitize(): - # Call the function with test data - queries = [ - ( - 'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /', - '\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? \\: \\\\ \\/', - ), - ('this has no escape characters', 'this has no escape characters'), - ] - - for query, assert_result in queries: - result = lucene_sanitize(query) - assert assert_result == result - - -async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int: - results, _, _ = await driver.execute_query( - """ - MATCH (n) - WHERE n.uuid IN $uuids - RETURN COUNT(n) as count - """, - uuids=uuids, - ) - return int(results[0]['count']) - - -async def get_edge_count(driver: GraphDriver, uuids: list[str]) -> int: - results, _, _ = await driver.execute_query( - """ - MATCH (n)-[e]->(m) - WHERE e.uuid IN $uuids - RETURN COUNT(e) as count - UNION ALL - MATCH (e:RelatesToNode_) - WHERE e.uuid IN $uuids - RETURN COUNT(e) as count - """, - uuids=uuids, - ) - return sum(int(result['count']) for result in results) - - -async def print_graph(graph_driver: GraphDriver): - nodes, _, _ = await graph_driver.execute_query( - """ - MATCH (n) - RETURN n.uuid, n.name - """, - ) - print('Nodes:') - for node in nodes: - print(' ', node) - edges, _, _ = await graph_driver.execute_query( - """ - MATCH (n)-[e]->(m) - RETURN n.name, e.uuid, m.name - """, - ) - print('Edges:') - for edge in edges: - print(' ', edge) - - -async def save_nodes(graph_driver: GraphDriver, *nodes: EntityNode): - await graph_driver.repositories.save_nodes(list(nodes)) - - -async def save_edges(graph_driver: GraphDriver, *edges: EntityEdge): - await graph_driver.repositories.save_edges(list(edges)) - - -async def delete_nodes(graph_driver: GraphDriver, *nodes: EntityNode): - for node in nodes: - await graph_driver.repositories.delete_node(node) - - -async def delete_edges(graph_driver: GraphDriver, *edges: EntityEdge): - for edge in edges: - await graph_driver.repositories.delete_edge(edge) - - -async def save_all(graph_driver: GraphDriver, *items: object): - repo = graph_driver.repositories - for item in items: - if isinstance(item, EntityNode | EpisodicNode | CommunityNode): - await repo.save_node(item) - elif isinstance(item, EntityEdge | EpisodicEdge | CommunityEdge): - await repo.save_edge(item) - else: - raise TypeError(f'Unsupported item type for save_all: {type(item)!r}') - - -async def delete_all(graph_driver: GraphDriver, *items: object): - repo = graph_driver.repositories - for item in items: - if 
isinstance(item, EntityNode | EpisodicNode | CommunityNode): - await repo.delete_node(item) - elif isinstance(item, EntityEdge | EpisodicEdge | CommunityEdge): - await repo.delete_edge(item) - else: - raise TypeError(f'Unsupported item type for delete_all: {type(item)!r}') - - -async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode): - assert retrieved.uuid == sample.uuid - assert retrieved.name == sample.name - assert retrieved.group_id == group_id - assert retrieved.created_at == sample.created_at - assert retrieved.source == sample.source - assert retrieved.source_description == sample.source_description - assert retrieved.content == sample.content - assert retrieved.valid_at == sample.valid_at - assert set(retrieved.entity_edges) == set(sample.entity_edges) - - -async def assert_entity_node_equals( - graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode -): - await graph_driver.repositories.entity_nodes.load_name_embedding(retrieved) - assert retrieved.uuid == sample.uuid - assert retrieved.name == sample.name - assert retrieved.group_id == sample.group_id - assert set(retrieved.labels) == set(sample.labels) - assert retrieved.created_at == sample.created_at - assert retrieved.name_embedding is not None - assert sample.name_embedding is not None - assert np.allclose(retrieved.name_embedding, sample.name_embedding) - assert retrieved.summary == sample.summary - assert retrieved.attributes_dict() == sample.attributes_dict() - - -async def assert_community_node_equals( - graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode -): - await graph_driver.repositories.community_nodes.load_name_embedding(retrieved) - assert retrieved.uuid == sample.uuid - assert retrieved.name == sample.name - assert retrieved.group_id == group_id - assert retrieved.created_at == sample.created_at - assert retrieved.name_embedding is not None - assert sample.name_embedding is not None - assert np.allclose(retrieved.name_embedding, sample.name_embedding) - assert retrieved.summary == sample.summary - - -async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge): - assert retrieved.uuid == sample.uuid - assert retrieved.group_id == sample.group_id - assert retrieved.created_at == sample.created_at - assert retrieved.source_node_uuid == sample.source_node_uuid - assert retrieved.target_node_uuid == sample.target_node_uuid - - -async def assert_entity_edge_equals( - graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge -): - await graph_driver.repositories.entity_edges.load_fact_embedding(retrieved) - assert retrieved.uuid == sample.uuid - assert retrieved.group_id == sample.group_id - assert retrieved.created_at == sample.created_at - assert retrieved.source_node_uuid == sample.source_node_uuid - assert retrieved.target_node_uuid == sample.target_node_uuid - assert retrieved.name == sample.name - assert retrieved.fact == sample.fact - assert retrieved.fact_embedding is not None - assert sample.fact_embedding is not None - assert np.allclose(retrieved.fact_embedding, sample.fact_embedding) - assert retrieved.episodes == sample.episodes - assert retrieved.expired_at == sample.expired_at - assert retrieved.valid_at == sample.valid_at - assert retrieved.invalid_at == sample.invalid_at - assert retrieved.attributes_dict() == sample.attributes_dict() - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..e9caed9 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. + +from __future__ import annotations + +from tests.integration.shared.fixtures_services import graph_driver # noqa: F401 diff --git a/tests/integration/core/__init__.py b/tests/integration/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/core/shared/__init__.py b/tests/integration/core/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_entity_exclusion_int.py b/tests/integration/core/shared/test_entity_exclusion.py similarity index 93% rename from tests/test_entity_exclusion_int.py rename to tests/integration/core/shared/test_entity_exclusion.py index a9e730f..7fea749 100644 --- a/tests/test_entity_exclusion_int.py +++ b/tests/integration/core/shared/test_entity_exclusion.py @@ -21,7 +21,6 @@ load_llm_settings, ) from graphium_core.validation import validate_excluded_entity_types -from tests.helpers_test import drivers, get_driver pytestmark = pytest.mark.integration pytest_plugins = ('pytest_asyncio',) @@ -80,15 +79,11 @@ async def graphium_client_bundle() -> Generator[ @pytest.mark.asyncio -@pytest.mark.parametrize( - 'driver', - drivers, -) -async def test_exclude_default_entity_type(driver, graphium_client_bundle): +async def test_exclude_default_entity_type(graph_driver, graphium_client_bundle): """Test excluding the default 'Entity' type while keeping custom types.""" llm_client, embedder, reranker = graphium_client_bundle graphium = Graphium( - graph_driver=get_driver(driver), + graph_driver=graph_driver, llm_client=llm_client, embedder=embedder, cross_encoder=reranker, @@ -143,15 +138,11 @@ async def test_exclude_default_entity_type(driver, graphium_client_bundle): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'driver', - drivers, -) -async def test_exclude_specific_custom_types(driver, graphium_client_bundle): +async def test_exclude_specific_custom_types(graph_driver, graphium_client_bundle): """Test excluding specific custom entity types while keeping others.""" llm_client, embedder, reranker = graphium_client_bundle graphium = Graphium( - graph_driver=get_driver(driver), + graph_driver=graph_driver, llm_client=llm_client, embedder=embedder, cross_encoder=reranker, @@ -212,15 +203,11 @@ async def test_exclude_specific_custom_types(driver, graphium_client_bundle): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'driver', - drivers, -) -async def test_exclude_all_types(driver, graphium_client_bundle): +async def test_exclude_all_types(graph_driver, graphium_client_bundle): """Test excluding all entity types (edge case).""" llm_client, embedder, reranker = graphium_client_bundle graphium = Graphium( - graph_driver=get_driver(driver), + graph_driver=graph_driver, llm_client=llm_client, embedder=embedder, cross_encoder=reranker, @@ -266,15 +253,11 @@ async def test_exclude_all_types(driver, graphium_client_bundle): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'driver', - drivers, -) -async def test_exclude_no_types(driver, graphium_client_bundle): +async def test_exclude_no_types(graph_driver, graphium_client_bundle): """Test normal behavior when no types are excluded (baseline test).""" llm_client, embedder, reranker = graphium_client_bundle graphium = Graphium( - graph_driver=get_driver(driver), + 
graph_driver=graph_driver, llm_client=llm_client, embedder=embedder, cross_encoder=reranker, @@ -354,15 +337,13 @@ def test_validation_invalid_excluded_types(): @pytest.mark.asyncio -@pytest.mark.parametrize( - 'driver', - drivers, -) -async def test_excluded_types_parameter_validation_in_add_episode(driver, graphium_client_bundle): +async def test_excluded_types_parameter_validation_in_add_episode( + graph_driver, graphium_client_bundle +): """Test that add_episode validates excluded_entity_types parameter.""" llm_client, embedder, reranker = graphium_client_bundle graphium = Graphium( - graph_driver=get_driver(driver), + graph_driver=graph_driver, llm_client=llm_client, embedder=embedder, cross_encoder=reranker, diff --git a/tests/test_graphium_int.py b/tests/integration/core/shared/test_graphium_bootstrap.py similarity index 97% rename from tests/test_graphium_int.py rename to tests/integration/core/shared/test_graphium_bootstrap.py index b2dc581..9f82ccb 100644 --- a/tests/test_graphium_int.py +++ b/tests/integration/core/shared/test_graphium_bootstrap.py @@ -6,11 +6,11 @@ import pytest +from graphium_core.driver.driver import GraphProvider from graphium_core.graphium import Graphium from graphium_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters from graphium_core.search.search_helpers import search_results_to_context_string from graphium_core.utils.datetime_utils import utc_now -from tests.helpers_test import GraphProvider pytestmark = pytest.mark.integration pytest_plugins = ('pytest_asyncio',) diff --git a/tests/test_graphium_mock.py b/tests/integration/core/shared/test_ingestion_pipeline.py similarity index 99% rename from tests/test_graphium_mock.py rename to tests/integration/core/shared/test_ingestion_pipeline.py index f2bd3ab..2696789 100644 --- a/tests/test_graphium_mock.py +++ b/tests/integration/core/shared/test_ingestion_pipeline.py @@ -8,6 +8,7 @@ import pytest from graphium_core.cross_encoder import RerankerClient +from graphium_core.driver.driver import GraphProvider from graphium_core.edges import CommunityEdge, EntityEdge, EpisodicEdge from graphium_core.graphium import Graphium from graphium_core.llm_client import LLMClient @@ -43,8 +44,7 @@ node_fulltext_search, node_similarity_search, ) -from tests.helpers_test import ( - GraphProvider, +from tests.helpers.services import ( assert_entity_edge_equals, assert_entity_node_equals, assert_episodic_edge_equals, @@ -52,8 +52,8 @@ delete_all, get_edge_count, get_node_count, - group_id, - group_id_2, + GROUP_ID as group_id, + GROUP_ID_ALT as group_id_2, save_all, save_edges, save_nodes, diff --git a/tests/test_edge_int.py b/tests/integration/core/shared/test_repository_edges.py similarity index 88% rename from tests/test_edge_int.py rename to tests/integration/core/shared/test_repository_edges.py index 42cdbd4..6dcaa8a 100644 --- a/tests/test_edge_int.py +++ b/tests/integration/core/shared/test_repository_edges.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Modified by the Graphium project. 
-import logging -import sys from datetime import UTC, datetime import numpy as np @@ -10,33 +8,11 @@ from graphium_core.edges import CommunityEdge, EntityEdge, EpisodicEdge from graphium_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode -from tests.helpers_test import get_edge_count, get_node_count, group_id +from tests.helpers.services import get_edge_count, get_node_count, GROUP_ID -pytest_plugins = ('pytest_asyncio',) pytestmark = pytest.mark.integration -def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO - - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - # Add formatter to console handler - console_handler.setFormatter(formatter) - - # Add console handler to logger - logger.addHandler(console_handler) - - return logger - - @pytest.mark.asyncio async def test_episodic_edge(graph_driver, mock_embedder): now = datetime.now(UTC) @@ -52,7 +28,7 @@ async def test_episodic_edge(graph_driver, mock_embedder): source_description='conversation message', content='Alice likes Bob', entity_edges=[], - group_id=group_id, + group_id=GROUP_ID, ) node_count = await get_node_count(graph_driver, [episode_node.uuid]) assert node_count == 0 @@ -66,7 +42,7 @@ async def test_episodic_edge(graph_driver, mock_embedder): labels=[], created_at=now, summary='Alice summary', - group_id=group_id, + group_id=GROUP_ID, ) await alice_node.generate_name_embedding(mock_embedder) node_count = await get_node_count(graph_driver, [alice_node.uuid]) @@ -80,7 +56,7 @@ async def test_episodic_edge(graph_driver, mock_embedder): source_node_uuid=episode_node.uuid, target_node_uuid=alice_node.uuid, created_at=now, - group_id=group_id, + group_id=GROUP_ID, ) edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid]) assert edge_count == 0 @@ -94,7 +70,7 @@ async def test_episodic_edge(graph_driver, mock_embedder): assert retrieved.source_node_uuid == episode_node.uuid assert retrieved.target_node_uuid == alice_node.uuid assert retrieved.created_at == now - assert retrieved.group_id == group_id + assert retrieved.group_id == GROUP_ID # Get edge by uuids retrieved = await repositories.episodic_edges.get_by_uuids([episodic_edge.uuid]) @@ -103,16 +79,16 @@ async def test_episodic_edge(graph_driver, mock_embedder): assert retrieved[0].source_node_uuid == episode_node.uuid assert retrieved[0].target_node_uuid == alice_node.uuid assert retrieved[0].created_at == now - assert retrieved[0].group_id == group_id + assert retrieved[0].group_id == GROUP_ID # Get edge by group ids - retrieved = await repositories.episodic_edges.get_by_group_ids([group_id], limit=2) + retrieved = await repositories.episodic_edges.get_by_group_ids([GROUP_ID], limit=2) assert len(retrieved) == 1 assert retrieved[0].uuid == episodic_edge.uuid assert retrieved[0].source_node_uuid == episode_node.uuid assert retrieved[0].target_node_uuid == alice_node.uuid assert retrieved[0].created_at == now - assert retrieved[0].group_id == group_id + assert retrieved[0].group_id == GROUP_ID # Get episodic node by entity node uuid retrieved = await repositories.episodic_nodes.get_by_entity_node_uuid(alice_node.uuid) @@ -120,7 +96,7 @@ async def test_episodic_edge(graph_driver, mock_embedder): assert retrieved[0].uuid == episode_node.uuid assert 
retrieved[0].name == 'test_episode' assert retrieved[0].created_at == now - assert retrieved[0].group_id == group_id + assert retrieved[0].group_id == GROUP_ID # Delete edge by uuid await repositories.delete_edge(episodic_edge) @@ -141,7 +117,6 @@ async def test_episodic_edge(graph_driver, mock_embedder): node_count = await get_node_count(graph_driver, [alice_node.uuid]) assert node_count == 0 - await graph_driver.close() @pytest.mark.asyncio @@ -155,7 +130,7 @@ async def test_entity_edge(graph_driver, mock_embedder): labels=[], created_at=now, summary='Alice summary', - group_id=group_id, + group_id=GROUP_ID, ) await alice_node.generate_name_embedding(mock_embedder) node_count = await get_node_count(graph_driver, [alice_node.uuid]) @@ -166,7 +141,7 @@ async def test_entity_edge(graph_driver, mock_embedder): # Create entity node bob_node = EntityNode( - name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id + name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=GROUP_ID ) await bob_node.generate_name_embedding(mock_embedder) node_count = await get_node_count(graph_driver, [bob_node.uuid]) @@ -186,7 +161,7 @@ async def test_entity_edge(graph_driver, mock_embedder): expired_at=now, valid_at=now, invalid_at=now, - group_id=group_id, + group_id=GROUP_ID, ) edge_embedding = await entity_edge.generate_embedding(mock_embedder) edge_count = await get_edge_count(graph_driver, [entity_edge.uuid]) @@ -201,7 +176,7 @@ async def test_entity_edge(graph_driver, mock_embedder): assert retrieved.source_node_uuid == alice_node.uuid assert retrieved.target_node_uuid == bob_node.uuid assert retrieved.created_at == now - assert retrieved.group_id == group_id + assert retrieved.group_id == GROUP_ID # Get edge by uuids retrieved = await repositories.entity_edges.get_by_uuids([entity_edge.uuid]) @@ -210,16 +185,16 @@ async def test_entity_edge(graph_driver, mock_embedder): assert retrieved[0].source_node_uuid == alice_node.uuid assert retrieved[0].target_node_uuid == bob_node.uuid assert retrieved[0].created_at == now - assert retrieved[0].group_id == group_id + assert retrieved[0].group_id == GROUP_ID # Get edge by group ids - retrieved = await repositories.entity_edges.get_by_group_ids([group_id], limit=2) + retrieved = await repositories.entity_edges.get_by_group_ids([GROUP_ID], limit=2) assert len(retrieved) == 1 assert retrieved[0].uuid == entity_edge.uuid assert retrieved[0].source_node_uuid == alice_node.uuid assert retrieved[0].target_node_uuid == bob_node.uuid assert retrieved[0].created_at == now - assert retrieved[0].group_id == group_id + assert retrieved[0].group_id == GROUP_ID # Get edge by node uuid retrieved = await repositories.entity_edges.get_by_node_uuid(alice_node.uuid) @@ -228,7 +203,7 @@ async def test_entity_edge(graph_driver, mock_embedder): assert retrieved[0].source_node_uuid == alice_node.uuid assert retrieved[0].target_node_uuid == bob_node.uuid assert retrieved[0].created_at == now - assert retrieved[0].group_id == group_id + assert retrieved[0].group_id == GROUP_ID # Get fact embedding await repositories.entity_edges.load_fact_embedding(entity_edge) @@ -279,7 +254,6 @@ async def test_entity_edge(graph_driver, mock_embedder): node_count = await get_node_count(graph_driver, [bob_node.uuid]) assert node_count == 0 - await graph_driver.close() @pytest.mark.asyncio @@ -290,7 +264,7 @@ async def test_community_edge(graph_driver, mock_embedder): # Create community node community_node_1 = CommunityNode( name='test_community_1', - group_id=group_id, 
+ group_id=GROUP_ID, summary='Community A summary', ) await community_node_1.generate_name_embedding(mock_embedder) @@ -303,7 +277,7 @@ async def test_community_edge(graph_driver, mock_embedder): # Create community node community_node_2 = CommunityNode( name='test_community_2', - group_id=group_id, + group_id=GROUP_ID, summary='Community B summary', ) await community_node_2.generate_name_embedding(mock_embedder) @@ -315,7 +289,7 @@ async def test_community_edge(graph_driver, mock_embedder): # Create entity node alice_node = EntityNode( - name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id + name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=GROUP_ID ) await alice_node.generate_name_embedding(mock_embedder) node_count = await get_node_count(graph_driver, [alice_node.uuid]) @@ -329,7 +303,7 @@ async def test_community_edge(graph_driver, mock_embedder): source_node_uuid=community_node_1.uuid, target_node_uuid=community_node_2.uuid, created_at=now, - group_id=group_id, + group_id=GROUP_ID, ) edge_count = await get_edge_count(graph_driver, [community_edge.uuid]) assert edge_count == 0 @@ -343,7 +317,7 @@ async def test_community_edge(graph_driver, mock_embedder): assert retrieved.source_node_uuid == community_node_1.uuid assert retrieved.target_node_uuid == community_node_2.uuid assert retrieved.created_at == now - assert retrieved.group_id == group_id + assert retrieved.group_id == GROUP_ID # Get edge by uuids retrieved = await repositories.community_edges.get_by_uuids([community_edge.uuid]) @@ -352,16 +326,16 @@ async def test_community_edge(graph_driver, mock_embedder): assert retrieved[0].source_node_uuid == community_node_1.uuid assert retrieved[0].target_node_uuid == community_node_2.uuid assert retrieved[0].created_at == now - assert retrieved[0].group_id == group_id + assert retrieved[0].group_id == GROUP_ID # Get edge by group ids - retrieved = await repositories.community_edges.get_by_group_ids([group_id], limit=1) + retrieved = await repositories.community_edges.get_by_group_ids([GROUP_ID], limit=1) assert len(retrieved) == 1 assert retrieved[0].uuid == community_edge.uuid assert retrieved[0].source_node_uuid == community_node_1.uuid assert retrieved[0].target_node_uuid == community_node_2.uuid assert retrieved[0].created_at == now - assert retrieved[0].group_id == group_id + assert retrieved[0].group_id == GROUP_ID # Delete edge by uuid await repositories.delete_edge(community_edge) @@ -384,5 +358,3 @@ async def test_community_edge(graph_driver, mock_embedder): await repositories.delete_node(community_node_2) node_count = await get_node_count(graph_driver, [community_node_2.uuid]) assert node_count == 0 - - await graph_driver.close() diff --git a/tests/test_node_int.py b/tests/integration/core/shared/test_repository_nodes.py similarity index 92% rename from tests/test_node_int.py rename to tests/integration/core/shared/test_repository_nodes.py index 5f111b8..0f9ce28 100644 --- a/tests/test_node_int.py +++ b/tests/integration/core/shared/test_repository_nodes.py @@ -12,20 +12,18 @@ EpisodeType, EpisodicNode, ) -from tests.helpers_test import ( +from tests.helpers.services import ( assert_community_node_equals, assert_entity_node_equals, assert_episodic_node_equals, get_node_count, - group_id, + GROUP_ID, ) pytestmark = pytest.mark.integration created_at = datetime.now(UTC) -deleted_at = created_at + timedelta(days=3) valid_at = created_at + timedelta(days=1) -invalid_at = created_at + timedelta(days=2) @pytest.fixture @@ -33,7 
+31,7 @@ def sample_entity_node(): return EntityNode( uuid=str(uuid4()), name='Test Entity', - group_id=group_id, + group_id=GROUP_ID, labels=['Entity', 'Person'], created_at=created_at, name_embedding=[0.5] * 1024, @@ -50,7 +48,7 @@ def sample_episodic_node(): return EpisodicNode( uuid=str(uuid4()), name='Episode 1', - group_id=group_id, + group_id=GROUP_ID, created_at=created_at, source=EpisodeType.text, source_description='Test source', @@ -65,7 +63,7 @@ def sample_community_node(): return CommunityNode( uuid=str(uuid4()), name='Community A', - group_id=group_id, + group_id=GROUP_ID, created_at=created_at, name_embedding=[0.5] * 1024, summary='Community summary', @@ -94,7 +92,7 @@ async def test_entity_node(sample_entity_node, graph_driver): # Get node by group ids retrieved = await repositories.entity_nodes.get_by_group_ids( - [group_id], limit=2, with_embeddings=True + [GROUP_ID], limit=2, with_embeddings=True ) assert len(retrieved) == 1 await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node) @@ -116,7 +114,7 @@ async def test_entity_node(sample_entity_node, graph_driver): await repositories.save_node(sample_entity_node) node_count = await get_node_count(graph_driver, [uuid]) assert node_count == 1 - await repositories.entity_nodes.delete_by_group_id(group_id) + await repositories.entity_nodes.delete_by_group_id(GROUP_ID) node_count = await get_node_count(graph_driver, [uuid]) assert node_count == 0 @@ -144,7 +142,7 @@ async def test_community_node(sample_community_node, graph_driver): await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node) # Get node by group ids - retrieved = await repositories.community_nodes.get_by_group_ids([group_id], limit=2) + retrieved = await repositories.community_nodes.get_by_group_ids([GROUP_ID], limit=2) assert len(retrieved) == 1 await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node) @@ -165,7 +163,7 @@ async def test_community_node(sample_community_node, graph_driver): await repositories.save_node(sample_community_node) node_count = await get_node_count(graph_driver, [uuid]) assert node_count == 1 - await repositories.community_nodes.delete_by_group_id(group_id) + await repositories.community_nodes.delete_by_group_id(GROUP_ID) node_count = await get_node_count(graph_driver, [uuid]) assert node_count == 0 @@ -193,7 +191,7 @@ async def test_episodic_node(sample_episodic_node, graph_driver): await assert_episodic_node_equals(retrieved[0], sample_episodic_node) # Get node by group ids - retrieved = await repositories.episodic_nodes.get_by_group_ids([group_id], limit=2) + retrieved = await repositories.episodic_nodes.get_by_group_ids([GROUP_ID], limit=2) assert len(retrieved) == 1 await assert_episodic_node_equals(retrieved[0], sample_episodic_node) @@ -214,7 +212,7 @@ async def test_episodic_node(sample_episodic_node, graph_driver): await repositories.save_node(sample_episodic_node) node_count = await get_node_count(graph_driver, [uuid]) assert node_count == 1 - await repositories.episodic_nodes.delete_by_group_id(group_id) + await repositories.episodic_nodes.delete_by_group_id(GROUP_ID) node_count = await get_node_count(graph_driver, [uuid]) assert node_count == 0 diff --git a/tests/integration/falkordb/__init__.py b/tests/integration/falkordb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/falkordb/test_placeholder.py b/tests/integration/falkordb/test_placeholder.py new file mode 100644 index 0000000..ad8ad4e --- /dev/null +++ 
b/tests/integration/falkordb/test_placeholder.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. + +import pytest + + +@pytest.mark.integration +@pytest.mark.falkordb +def test_falkordb_pipeline_placeholder(): + pytest.skip('Placeholder for FalkorDB ingestion/search integration suite.') diff --git a/tests/integration/kuzu/__init__.py b/tests/integration/kuzu/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/kuzu/test_placeholder.py b/tests/integration/kuzu/test_placeholder.py new file mode 100644 index 0000000..c9c60d0 --- /dev/null +++ b/tests/integration/kuzu/test_placeholder.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. + +import pytest + + +@pytest.mark.integration +@pytest.mark.kuzu +def test_kuzu_pipeline_placeholder(): + pytest.skip('Placeholder for Kuzu-backed integration coverage.') diff --git a/tests/llm_client/test_anthropic_client_int.py b/tests/integration/llm_client/test_anthropic_client.py similarity index 100% rename from tests/llm_client/test_anthropic_client_int.py rename to tests/integration/llm_client/test_anthropic_client.py diff --git a/tests/integration/neo4j/__init__.py b/tests/integration/neo4j/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/neo4j/test_placeholder.py b/tests/integration/neo4j/test_placeholder.py new file mode 100644 index 0000000..6852cf5 --- /dev/null +++ b/tests/integration/neo4j/test_placeholder.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. + +import pytest + + +@pytest.mark.integration +@pytest.mark.neo4j +def test_neo4j_pipeline_placeholder(): + pytest.skip('Pending Neo4j-backed integration coverage (ingest/search smoke).') diff --git a/tests/integration/shared/__init__.py b/tests/integration/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/shared/fixtures_services.py b/tests/integration/shared/fixtures_services.py new file mode 100644 index 0000000..b88fcbc --- /dev/null +++ b/tests/integration/shared/fixtures_services.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. 
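(Illustrative aside, not part of this patch.) The per-provider files above are deliberate placeholders; a hypothetical shape for a filled-in Neo4j test, reusing the helpers this patch introduces — the test name, the round-trip assertions, and the choice of a single-node check are assumptions:

```python
from datetime import UTC, datetime

import pytest

from graphium_core.driver.driver import GraphProvider
from graphium_core.nodes import EntityNode
from tests.helpers.services import GROUP_ID, get_node_count, prepare_driver, teardown_driver


@pytest.mark.integration
@pytest.mark.neo4j
@pytest.mark.asyncio
async def test_neo4j_entity_roundtrip(mock_embedder):
    # Provision a Neo4j driver and clear the shared test groups (helpers from
    # tests/helpers/services.py), then round-trip a single entity node.
    driver = await prepare_driver(GraphProvider.NEO4J)
    try:
        node = EntityNode(
            name='Alice',
            labels=[],
            created_at=datetime.now(UTC),
            summary='Alice summary',
            group_id=GROUP_ID,
        )
        await node.generate_name_embedding(mock_embedder)
        await driver.repositories.save_node(node)
        assert await get_node_count(driver, [node.uuid]) == 1
    finally:
        await teardown_driver(driver)
```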
+ +"""Service-scoped fixtures reused across integration suites.""" + +from __future__ import annotations + +import pytest + +from graphium_core.driver.driver import GraphProvider + +from tests.helpers import services + +PROVIDER_MARKER = { + GraphProvider.NEO4J: 'neo4j', + GraphProvider.FALKORDB: 'falkordb', + GraphProvider.KUZU: 'kuzu', + GraphProvider.NEPTUNE: 'neptune', +} + + +def _integration_providers() -> list[GraphProvider]: + providers = services.available_providers() + if not providers: + pytest.skip('No graph providers enabled for integration tests') + return providers + + +@pytest.fixture(params=_integration_providers()) +async def graph_driver(request: pytest.FixtureRequest): + provider: GraphProvider = request.param + marker = PROVIDER_MARKER.get(provider) + if marker: + request.node.add_marker(marker) + driver = await services.prepare_driver(provider) + try: + yield driver + finally: + await services.teardown_driver(driver) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/cross_encoder/__init__.py b/tests/unit/cross_encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cross_encoder/test_bge_reranker_client.py b/tests/unit/cross_encoder/test_bge_reranker_client.py similarity index 100% rename from tests/cross_encoder/test_bge_reranker_client.py rename to tests/unit/cross_encoder/test_bge_reranker_client.py diff --git a/tests/cross_encoder/test_gemini_reranker_client.py b/tests/unit/cross_encoder/test_gemini_reranker_client.py similarity index 100% rename from tests/cross_encoder/test_gemini_reranker_client.py rename to tests/unit/cross_encoder/test_gemini_reranker_client.py diff --git a/tests/unit/drivers/__init__.py b/tests/unit/drivers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/driver/test_falkordb_driver.py b/tests/unit/drivers/test_falkordb_driver.py similarity index 100% rename from tests/driver/test_falkordb_driver.py rename to tests/unit/drivers/test_falkordb_driver.py diff --git a/tests/unit/embedder/__init__.py b/tests/unit/embedder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/embedder/embedder_fixtures.py b/tests/unit/embedder/embedder_fixtures.py similarity index 100% rename from tests/embedder/embedder_fixtures.py rename to tests/unit/embedder/embedder_fixtures.py diff --git a/tests/embedder/test_embeddinggemma.py b/tests/unit/embedder/test_embeddinggemma.py similarity index 100% rename from tests/embedder/test_embeddinggemma.py rename to tests/unit/embedder/test_embeddinggemma.py diff --git a/tests/embedder/test_gemini.py b/tests/unit/embedder/test_gemini.py similarity index 99% rename from tests/embedder/test_gemini.py rename to tests/unit/embedder/test_gemini.py index 65b562f..c2ae290 100644 --- a/tests/embedder/test_gemini.py +++ b/tests/unit/embedder/test_gemini.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # Modified by the Graphium project. 
-# Running tests: pytest -xvs tests/embedder/test_gemini.py +# Running tests: uv run pytest tests/unit/embedder/test_gemini.py -xvs from collections.abc import Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest -from embedder_fixtures import create_embedding_values + +from .embedder_fixtures import create_embedding_values from graphium_core.embedder import GeminiEmbedder, GeminiEmbedderConfig from graphium_core.embedder.providers.gemini import DEFAULT_EMBEDDING_MODEL diff --git a/tests/embedder/test_openai.py b/tests/unit/embedder/test_openai.py similarity index 98% rename from tests/embedder/test_openai.py rename to tests/unit/embedder/test_openai.py index 2334e09..c34181c 100644 --- a/tests/embedder/test_openai.py +++ b/tests/unit/embedder/test_openai.py @@ -9,7 +9,8 @@ from graphium_core.embedder import OpenAIEmbedder, OpenAIEmbedderConfig from graphium_core.embedder.providers.openai import DEFAULT_EMBEDDING_MODEL -from tests.embedder.embedder_fixtures import create_embedding_values + +from .embedder_fixtures import create_embedding_values def create_openai_embedding(multiplier: float = 0.1) -> MagicMock: diff --git a/tests/embedder/test_voyage.py b/tests/unit/embedder/test_voyage.py similarity index 98% rename from tests/embedder/test_voyage.py rename to tests/unit/embedder/test_voyage.py index 250c584..b9f606b 100644 --- a/tests/embedder/test_voyage.py +++ b/tests/unit/embedder/test_voyage.py @@ -9,7 +9,8 @@ from graphium_core.embedder import VoyageAIEmbedder, VoyageAIEmbedderConfig from graphium_core.embedder.providers.voyage import DEFAULT_EMBEDDING_MODEL -from tests.embedder.embedder_fixtures import create_embedding_values + +from .embedder_fixtures import create_embedding_values @pytest.fixture diff --git a/tests/unit/llm_client/__init__.py b/tests/unit/llm_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/llm_client/test_anthropic_client.py b/tests/unit/llm_client/test_anthropic_client.py similarity index 100% rename from tests/llm_client/test_anthropic_client.py rename to tests/unit/llm_client/test_anthropic_client.py diff --git a/tests/llm_client/test_client.py b/tests/unit/llm_client/test_client.py similarity index 100% rename from tests/llm_client/test_client.py rename to tests/unit/llm_client/test_client.py diff --git a/tests/llm_client/test_errors.py b/tests/unit/llm_client/test_errors.py similarity index 100% rename from tests/llm_client/test_errors.py rename to tests/unit/llm_client/test_errors.py diff --git a/tests/llm_client/test_gemini_client.py b/tests/unit/llm_client/test_gemini_client.py similarity index 100% rename from tests/llm_client/test_gemini_client.py rename to tests/unit/llm_client/test_gemini_client.py diff --git a/tests/llm_client/test_groq_client.py b/tests/unit/llm_client/test_groq_client.py similarity index 100% rename from tests/llm_client/test_groq_client.py rename to tests/unit/llm_client/test_groq_client.py diff --git a/tests/llm_client/test_litellm_client.py b/tests/unit/llm_client/test_litellm_client.py similarity index 100% rename from tests/llm_client/test_litellm_client.py rename to tests/unit/llm_client/test_litellm_client.py diff --git a/tests/llm_client/test_pydantic_ai_adapter.py b/tests/unit/llm_client/test_pydantic_ai_adapter.py similarity index 100% rename from tests/llm_client/test_pydantic_ai_adapter.py rename to tests/unit/llm_client/test_pydantic_ai_adapter.py diff --git a/tests/llm_client/test_structured_output.py 
b/tests/unit/llm_client/test_structured_output.py similarity index 100% rename from tests/llm_client/test_structured_output.py rename to tests/unit/llm_client/test_structured_output.py diff --git a/tests/unit/mcp/__init__.py b/tests/unit/mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/mcp/test_episode_queue.py b/tests/unit/mcp/test_episode_queue.py similarity index 100% rename from tests/mcp/test_episode_queue.py rename to tests/unit/mcp/test_episode_queue.py diff --git a/tests/unit/orchestration/__init__.py b/tests/unit/orchestration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/orchestration/test_bulk.py b/tests/unit/orchestration/test_bulk.py similarity index 100% rename from tests/orchestration/test_bulk.py rename to tests/unit/orchestration/test_bulk.py diff --git a/tests/orchestration/test_bulk_serialization.py b/tests/unit/orchestration/test_bulk_serialization.py similarity index 100% rename from tests/orchestration/test_bulk_serialization.py rename to tests/unit/orchestration/test_bulk_serialization.py diff --git a/tests/orchestration/test_episode_orchestrator.py b/tests/unit/orchestration/test_episode_orchestrator.py similarity index 100% rename from tests/orchestration/test_episode_orchestrator.py rename to tests/unit/orchestration/test_episode_orchestrator.py diff --git a/tests/test_graphium_factory_usage.py b/tests/unit/orchestration/test_initializer_factory.py similarity index 100% rename from tests/test_graphium_factory_usage.py rename to tests/unit/orchestration/test_initializer_factory.py diff --git a/tests/orchestration/test_node_operations_sequence.py b/tests/unit/orchestration/test_node_operations_sequence.py similarity index 100% rename from tests/orchestration/test_node_operations_sequence.py rename to tests/unit/orchestration/test_node_operations_sequence.py diff --git a/tests/unit/providers/__init__.py b/tests/unit/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/providers/test_factory.py b/tests/unit/providers/test_factory.py similarity index 100% rename from tests/providers/test_factory.py rename to tests/unit/providers/test_factory.py diff --git a/tests/unit/search/__init__.py b/tests/unit/search/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/search/test_edge_search_orchestration.py b/tests/unit/search/test_edge_search_orchestration.py similarity index 100% rename from tests/search/test_edge_search_orchestration.py rename to tests/unit/search/test_edge_search_orchestration.py diff --git a/tests/unit/search/test_lucene_utils.py b/tests/unit/search/test_lucene_utils.py new file mode 100644 index 0000000..67761f5 --- /dev/null +++ b/tests/unit/search/test_lucene_utils.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# Modified by the Graphium project. + +from graphium_core.search.lucene import sanitize as lucene_sanitize + + +def test_lucene_sanitize(): + queries = [ + ( + 'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /', + '\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? 
\\: \\\\ \\/', + ), + ('this has no escape characters', 'this has no escape characters'), + ] + + for query, expected in queries: + assert lucene_sanitize(query) == expected diff --git a/tests/search/test_search_filters.py b/tests/unit/search/test_search_filters.py similarity index 100% rename from tests/search/test_search_filters.py rename to tests/unit/search/test_search_filters.py diff --git a/tests/search/test_search_helpers.py b/tests/unit/search/test_search_helpers.py similarity index 100% rename from tests/search/test_search_helpers.py rename to tests/unit/search/test_search_helpers.py diff --git a/tests/search/test_search_utils_edges.py b/tests/unit/search/test_search_utils_edges.py similarity index 100% rename from tests/search/test_search_utils_edges.py rename to tests/unit/search/test_search_utils_edges.py diff --git a/tests/search/test_search_utils_filters.py b/tests/unit/search/test_search_utils_filters.py similarity index 100% rename from tests/search/test_search_utils_filters.py rename to tests/unit/search/test_search_utils_filters.py diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/maintenance/test_bulk_utils.py b/tests/unit/utils/maintenance/test_bulk_utils.py similarity index 100% rename from tests/utils/maintenance/test_bulk_utils.py rename to tests/unit/utils/maintenance/test_bulk_utils.py diff --git a/tests/utils/maintenance/test_edge_operations.py b/tests/unit/utils/maintenance/test_edge_operations.py similarity index 100% rename from tests/utils/maintenance/test_edge_operations.py rename to tests/unit/utils/maintenance/test_edge_operations.py diff --git a/tests/utils/maintenance/test_node_operations.py b/tests/unit/utils/maintenance/test_node_operations.py similarity index 100% rename from tests/utils/maintenance/test_node_operations.py rename to tests/unit/utils/maintenance/test_node_operations.py diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/unit/utils/maintenance/test_temporal_operations.py similarity index 100% rename from tests/utils/maintenance/test_temporal_operations_int.py rename to tests/unit/utils/maintenance/test_temporal_operations.py diff --git a/tests/utils/search/search_utils_test.py b/tests/unit/utils/search/test_hybrid_search.py similarity index 100% rename from tests/utils/search/search_utils_test.py rename to tests/unit/utils/search/test_hybrid_search.py diff --git a/tests/test_text_utils.py b/tests/unit/utils/test_text_utils.py similarity index 100% rename from tests/test_text_utils.py rename to tests/unit/utils/test_text_utils.py From ea22fcdf7ff6c230254ff644dce06615ed77a3da Mon Sep 17 00:00:00 2001 From: Luca Candela Date: Sat, 11 Oct 2025 22:23:56 -0700 Subject: [PATCH 2/7] gracefully skip integration tests when providers disabled --- tests/integration/shared/fixtures_services.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/integration/shared/fixtures_services.py b/tests/integration/shared/fixtures_services.py index b88fcbc..3ee3406 100644 --- a/tests/integration/shared/fixtures_services.py +++ b/tests/integration/shared/fixtures_services.py @@ -18,17 +18,21 @@ GraphProvider.NEPTUNE: 'neptune', } +_NO_PROVIDER_MSG = 'No graph providers enabled for integration tests' -def _integration_providers() -> list[GraphProvider]: + +def _integration_providers() -> list[object]: providers = services.available_providers() - if not providers: - pytest.skip('No graph providers enabled for 
integration tests') - return providers + if providers: + return providers + return [pytest.param(None, marks=pytest.mark.skip(_NO_PROVIDER_MSG))] @pytest.fixture(params=_integration_providers()) async def graph_driver(request: pytest.FixtureRequest): - provider: GraphProvider = request.param + provider = request.param + if provider is None: + pytest.skip(_NO_PROVIDER_MSG) marker = PROVIDER_MARKER.get(provider) if marker: request.node.add_marker(marker) From 750f4e79265bbf6b61ac0edda4ab97a22ba2b728 Mon Sep 17 00:00:00 2001 From: Luca Candela Date: Sat, 11 Oct 2025 22:38:32 -0700 Subject: [PATCH 3/7] address coderabbit feedback for test restructure --- TEST_RESTRUCTURE_PLAN.md | 4 ++-- tests/helpers/embeddings.py | 5 ++--- tests/helpers/services.py | 14 +++++++++++--- tests/integration/shared/fixtures_services.py | 6 +++--- tests/unit/search/test_lucene_utils.py | 15 +++++++++------ 5 files changed, 27 insertions(+), 17 deletions(-) diff --git a/TEST_RESTRUCTURE_PLAN.md b/TEST_RESTRUCTURE_PLAN.md index f4bd37b..95007f3 100644 --- a/TEST_RESTRUCTURE_PLAN.md +++ b/TEST_RESTRUCTURE_PLAN.md @@ -16,7 +16,7 @@ ## Proposed Directory Layout -``` +```text tests/ ├─ unit/ │ ├─ orchestration/ @@ -58,7 +58,7 @@ tests/ Markers to add in `pytest.ini`: -``` +```ini [pytest] markers = integration: tests that require external services diff --git a/tests/helpers/embeddings.py b/tests/helpers/embeddings.py index 8a23a12..12b17d4 100644 --- a/tests/helpers/embeddings.py +++ b/tests/helpers/embeddings.py @@ -39,9 +39,8 @@ def default_embeddings() -> dict[str, list[float]]: 'test_community_1', 'test_community_2', ] - embeddings = { - key: np.random.uniform(0.0, 0.9, EMBEDDING_DIM).tolist() for key in keys - } + rng = np.random.Generator(np.random.PCG64(42)) + embeddings = {key: rng.uniform(0.0, 0.9, EMBEDDING_DIM).tolist() for key in keys} embeddings['Alice Smith'] = embeddings['Alice'] return embeddings diff --git a/tests/helpers/services.py b/tests/helpers/services.py index b19ef56..f8f5e0f 100644 --- a/tests/helpers/services.py +++ b/tests/helpers/services.py @@ -37,6 +37,10 @@ GROUP_ID = 'graphium_test_group' GROUP_ID_ALT = 'graphium_test_group_2' +_MSG_UNSUPPORTED_PROVIDER = 'Unsupported provider: {provider!r}' +_MSG_UNSUPPORTED_SAVE = 'Unsupported item type for save_all: {item_type!r}' +_MSG_UNSUPPORTED_DELETE = 'Unsupported item type for delete_all: {item_type!r}' + def _provider_enabled(flag: str) -> bool: return os.getenv(flag) is None @@ -103,7 +107,7 @@ def make_driver(provider: GraphProvider) -> GraphDriver: port=int(settings.port), aoss_host=settings.aoss_host, ) - raise ValueError(f'Unsupported provider: {provider!r}') + raise ValueError(_MSG_UNSUPPORTED_PROVIDER.format(provider=provider)) async def prepare_driver(provider: GraphProvider) -> GraphDriver: @@ -143,7 +147,9 @@ async def save_all(graph_driver: GraphDriver, *items: object) -> None: elif isinstance(item, (EntityEdge, EpisodicEdge, CommunityEdge)): await repo.save_edge(item) else: - raise TypeError(f'Unsupported item type for save_all: {type(item)!r}') + raise TypeError( + _MSG_UNSUPPORTED_SAVE.format(item_type=type(item)) + ) async def delete_all(graph_driver: GraphDriver, *items: object) -> None: @@ -154,7 +160,9 @@ async def delete_all(graph_driver: GraphDriver, *items: object) -> None: elif isinstance(item, (EntityEdge, EpisodicEdge, CommunityEdge)): await repo.delete_edge(item) else: - raise TypeError(f'Unsupported item type for delete_all: {type(item)!r}') + raise TypeError( + _MSG_UNSUPPORTED_DELETE.format(item_type=type(item)) + ) 
async def get_node_count(driver: GraphDriver, uuids: Iterable[str]) -> int: diff --git a/tests/integration/shared/fixtures_services.py b/tests/integration/shared/fixtures_services.py index 3ee3406..6244878 100644 --- a/tests/integration/shared/fixtures_services.py +++ b/tests/integration/shared/fixtures_services.py @@ -33,9 +33,9 @@ async def graph_driver(request: pytest.FixtureRequest): provider = request.param if provider is None: pytest.skip(_NO_PROVIDER_MSG) - marker = PROVIDER_MARKER.get(provider) - if marker: - request.node.add_marker(marker) + marker_name = PROVIDER_MARKER.get(provider) + if marker_name: + request.node.add_marker(getattr(pytest.mark, marker_name)) driver = await services.prepare_driver(provider) try: yield driver diff --git a/tests/unit/search/test_lucene_utils.py b/tests/unit/search/test_lucene_utils.py index 67761f5..a354a80 100644 --- a/tests/unit/search/test_lucene_utils.py +++ b/tests/unit/search/test_lucene_utils.py @@ -1,17 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # Modified by the Graphium project. +import pytest + from graphium_core.search.lucene import sanitize as lucene_sanitize -def test_lucene_sanitize(): - queries = [ +@pytest.mark.parametrize( + ('query', 'expected'), + [ ( 'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /', '\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? \\: \\\\ \\/', ), ('this has no escape characters', 'this has no escape characters'), - ] - - for query, expected in queries: - assert lucene_sanitize(query) == expected + ], +) +def test_lucene_sanitize(query: str, expected: str) -> None: + assert lucene_sanitize(query) == expected From ea4941ac0ead6a1111989e4befe63af08ffdd6c0 Mon Sep 17 00:00:00 2001 From: Luca Candela Date: Sat, 11 Oct 2025 22:39:56 -0700 Subject: [PATCH 4/7] fix helper group id assertions --- tests/helpers/services.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/helpers/services.py b/tests/helpers/services.py index f8f5e0f..9bd1aec 100644 --- a/tests/helpers/services.py +++ b/tests/helpers/services.py @@ -215,7 +215,7 @@ async def assert_community_node_equals( await graph_driver.repositories.community_nodes.load_name_embedding(retrieved) assert retrieved.uuid == sample.uuid assert retrieved.name == sample.name - assert retrieved.group_id == GROUP_ID + assert retrieved.group_id == sample.group_id assert retrieved.created_at == sample.created_at assert retrieved.name_embedding is not None assert sample.name_embedding is not None @@ -226,7 +226,7 @@ async def assert_community_node_equals( async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode) -> None: assert retrieved.uuid == sample.uuid assert retrieved.name == sample.name - assert retrieved.group_id == GROUP_ID + assert retrieved.group_id == sample.group_id assert retrieved.created_at == sample.created_at assert retrieved.source == sample.source assert retrieved.source_description == sample.source_description From 835ca6ef213bd5bc255ecca2453827d32a9a6c33 Mon Sep 17 00:00:00 2001 From: Luca Candela Date: Sat, 11 Oct 2025 22:50:43 -0700 Subject: [PATCH 5/7] use pytest_plugins in integration conftest --- tests/integration/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index e9caed9..056ce5e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -3,4 +3,4 @@ from __future__ import annotations -from 
tests.integration.shared.fixtures_services import graph_driver # noqa: F401 +pytest_plugins = ['tests.integration.shared.fixtures_services'] From 3d69a2e588586983daa1091cb96fe611110b5797 Mon Sep 17 00:00:00 2001 From: Luca Candela Date: Sat, 11 Oct 2025 22:51:26 -0700 Subject: [PATCH 6/7] document neptune default in test helper --- tests/helpers/services.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/helpers/services.py b/tests/helpers/services.py index 9bd1aec..d3c7598 100644 --- a/tests/helpers/services.py +++ b/tests/helpers/services.py @@ -6,7 +6,9 @@ These utilities centralise driver provisioning and common repository helpers so the individual test modules no longer need to import heavy fixtures directly. Nothing in this module triggers driver imports at import time; providers are -resolved lazily to keep unit collections lightweight. +resolved lazily to keep unit collections lightweight. Neptune support remains +disabled unless `DISABLE_NEPTUNE` is cleared in the environment before this +module is imported. """ from __future__ import annotations From 8d44edf3488ca5e673c309c0efbc61657bf33642 Mon Sep 17 00:00:00 2001 From: Luca Candela Date: Sat, 11 Oct 2025 22:53:39 -0700 Subject: [PATCH 7/7] document clean code lessons in AGENTS --- AGENTS.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index ef648e8..1b274da 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -52,6 +52,15 @@ Author tests alongside features under `tests/`, naming files `test_.py` - Use `DISABLE_NEO4J=1` (and similar env vars) when running unit-only workflows. - `tests/evals/` contains end-to-end evaluation scripts. +### Clean Code & Test Hygiene +- Prefer `@pytest.mark.parametrize` over manual loops for clearer, isolated assertions. +- Seed random fixtures (embeddings, UUID helpers, etc.) so test runs are deterministic. +- Reuse shared constants (e.g., `GROUP_ID`) instead of hard-coding values inside tests. +- Register fixtures via `pytest_plugins = [...]` rather than importing them solely for side effects; drop unnecessary `noqa` directives once unused imports disappear. +- Tag fenced markdown blocks with an explicit language (` ```text `, ` ```ini `) to satisfy linting. +- When defaulting environment flags (e.g., `DISABLE_NEPTUNE`), document the behaviour and let caller-provided values take precedence. +- For provider-specific integration fixtures, parameterise with sentinels and skip inside the fixture so CI reports a skip instead of aborting when services are disabled. + ## Commit & Pull Request Guidelines Commits use an imperative, present-tense summary (for example, `add async cache invalidation`) optionally suffixed with the PR number as seen in history (`(#927)`). Squash fixups and keep unrelated changes isolated. Pull requests should include: a concise description, linked tracking issue, notes about schema or API impacts, and screenshots or logs when behavior changes. Confirm `make lint` and `make test` pass locally, and update docs or examples when public interfaces shift.
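
The hygiene bullets above map onto a handful of concrete pytest patterns. The sketch below is illustrative only and uses invented names (`_FAKE_DISABLE_FLAGS`, `_enabled_providers`, `provider_name`) rather than anything shipped in the Graphium tree; it shows the skip-marked sentinel that lets provider fixtures degrade to a reported skip, plus a seeded NumPy generator and `@pytest.mark.parametrize` in place of a manual loop.

```python
import os

import numpy as np
import pytest

_NO_PROVIDER_MSG = 'No graph providers enabled for integration tests'
# Hypothetical provider-to-flag map; a stand-in for tests.helpers.services.
_FAKE_DISABLE_FLAGS = {'neo4j': 'DISABLE_NEO4J', 'falkordb': 'DISABLE_FALKORDB'}


def _enabled_providers() -> list[str]:
    # A provider counts as enabled only while its DISABLE_* flag is absent.
    return [name for name, flag in _FAKE_DISABLE_FLAGS.items() if os.getenv(flag) is None]


def _provider_params() -> list[object]:
    providers = _enabled_providers()
    if providers:
        return providers
    # A skip-marked sentinel keeps collection alive when every provider is off,
    # so CI reports a skip instead of failing during fixture setup.
    return [pytest.param(None, marks=pytest.mark.skip(reason=_NO_PROVIDER_MSG))]


@pytest.fixture(params=_provider_params())
def provider_name(request: pytest.FixtureRequest) -> str:
    if request.param is None:
        pytest.skip(_NO_PROVIDER_MSG)
    return request.param


@pytest.mark.parametrize('dim', [1, 4, 16])
def test_seeded_vectors_are_reproducible(dim: int) -> None:
    # Seeding the generator makes the vectors identical on every run, which is
    # what keeps embedding fixtures deterministic.
    first = np.random.Generator(np.random.PCG64(42)).uniform(0.0, 0.9, dim)
    second = np.random.Generator(np.random.PCG64(42)).uniform(0.0, 0.9, dim)
    assert np.allclose(first, second)
```

Each parametrized case is reported as its own test, and the sentinel path surfaces as a skip carrying `_NO_PROVIDER_MSG` as the reason.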
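
For the environment-flag bullet, one way such a default could be expressed, again purely as a sketch: the tuple `_DEFAULT_DISABLED_FLAGS` and the enable/disable convention below are invented for illustration and do not describe the actual mechanism in `tests/helpers/services.py`.

```python
import os

# Illustrative default list; a real helper would document its own policy.
_DEFAULT_DISABLED_FLAGS = ('DISABLE_NEPTUNE',)

for _flag in _DEFAULT_DISABLED_FLAGS:
    # setdefault only writes the flag when the caller has not already set it,
    # so any value supplied in the environment takes precedence over the default.
    os.environ.setdefault(_flag, '1')


def provider_enabled(flag: str) -> bool:
    # Treat an empty value or '0' as enabled and anything else as disabled, so a
    # caller can override the default above without unsetting the variable.
    return os.getenv(flag, '') in ('', '0')
```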
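
Finally, a minimal registration sketch for the `pytest_plugins` bullet. It assumes the declaration lives in the repository's top-level `conftest.py`; recent pytest releases reject `pytest_plugins` in conftest files below the rootdir, so placement is part of the pattern.

```python
# conftest.py at the repository root (hypothetical placement).
pytest_plugins = ['tests.integration.shared.fixtures_services']
```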