From d25a146028773474e5294d47f73ed24b4936020e Mon Sep 17 00:00:00 2001 From: LyriaClaw Date: Fri, 20 Feb 2026 08:16:20 +0000 Subject: [PATCH 1/2] docs: add two-pillar counterfactual on/off benchmark protocol --- README.md | 2 ++ docs/FULL_BENCHMARK_PLAN.md | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/README.md b/README.md index 13b9929..1c26597 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,8 @@ scripts/run_lancedb_vs_openclaw_mem_assisted.sh \ Deterministic long-run profile (stable run-group path for reproducible reruns): +For the counterfactual ON/OFF protocol (Pillar A execute now; Pillar B pre-registered), see `docs/FULL_BENCHMARK_PLAN.md#counterfactual-onoff-plan-for-the-two-pillars`. + ```bash scripts/run_phase_ab_longmemeval50.sh # writes to artifacts/phase-ab-compare/phase-ab-longmemeval50-seed7-topk10/ diff --git a/docs/FULL_BENCHMARK_PLAN.md b/docs/FULL_BENCHMARK_PLAN.md index f90ebfa..cd2bf9c 100644 --- a/docs/FULL_BENCHMARK_PLAN.md +++ b/docs/FULL_BENCHMARK_PLAN.md @@ -68,6 +68,44 @@ Required reporting: - If available from provider/tool payloads, include in compare artifact. - Current v0.2 retrieval reports do not provide tokenized cost telemetry; compare artifact records this as unavailable. +## Counterfactual ON/OFF plan for the two pillars + +This section adds a falsifiable ON/OFF design while keeping current Phase A/B priorities intact. + +### Scope and sequencing +- **Pillar A (execute now):** context pack contract hardening effects. +- **Pillar B (pre-register now, execute later):** learning-record/self-improving loop effects. +- Do not mix A and B rollout in the same implementation window. + +### Experimental arms +- `A0/B0`: baseline behavior (current pack contract, no learning block). +- `A1/B0`: Pillar A ON (contract hardening enabled). +- `A0/B1`: Pillar B ON (reserved; spec only until Pillar A gate passes). +- `A1/B1`: both ON (reserved for later confirmation run). 
+ +For the current cycle, run only `A0/B0` vs `A1/B0`. + +### Metric definitions (anti-gaming, explicit) +- **Recall@K / Precision@K / nDCG@K**: + - `K` must be fixed per run and written in the manifest (`top_k`, default `10`). + - the relevance source must be the dataset's `relevant_session_ids` (no post-hoc re-labeling). +- **Citation coverage**: + - `1 - (included_without_citation_count / included_count)`. + - also report numerator/denominator explicitly. +- **Rationale coverage**: + - `1 - (included_without_reason_count / included_count)`. +- **Determinism pass rate**: + - For each fixed DB/query case, run 5 repeats and compare canonicalized JSON (excluding timestamp fields only). + - the pass-rate denominator must be the number of distinct cases, not total runs. +- **Budget exclusion rate**: + - `excluded_by_budget / total_candidates`. +- **Latency p50/p95**: + - measured on the retrieval+pack path under the same runner/hardware profile recorded in the manifest. + +### Decision posture +- Treat this as a non-regression gate first (quality and determinism). +- Only after Pillar A gate passes should Pillar B execution be scheduled. 
+ ## Reproducibility contract Each run package must include a manifest with: From 7be0562d9e4ba2d05abd3f264322cac05b890f60 Mon Sep 17 00:00:00 2001 From: LyriaClaw Date: Fri, 20 Feb 2026 08:38:25 +0000 Subject: [PATCH 2/2] chore: ruff format + fix missing os import --- scripts/hybrid_two_stage_from_reports.py | 12 +++- .../run_lancedb_vs_openclaw_mem_assisted.py | 48 +++++++++++---- scripts/run_longmemeval50_qa_compare.py | 55 +++++++++++------ .../run_longmemeval50_qa_compare_openclaw.py | 55 +++++++++++++---- scripts/run_memory_core_vs_openclaw_mem.py | 9 ++- scripts/run_memory_triplet_comprehensive.py | 28 ++++++--- scripts/run_two_plugin_baseline.py | 44 ++++++++++---- scripts/summarize_phase_ab_multiseed.py | 30 +++++++--- .../adapters/memory_core.py | 6 +- .../adapters/memory_lancedb.py | 4 +- .../adapters/openclaw_mem.py | 6 +- src/openclaw_memory_bench/cli.py | 22 +++++-- src/openclaw_memory_bench/converters.py | 12 +++- src/openclaw_memory_bench/gateway_client.py | 8 ++- src/openclaw_memory_bench/hybrid.py | 16 +++-- src/openclaw_memory_bench/metrics.py | 4 +- src/openclaw_memory_bench/runner.py | 8 ++- src/openclaw_memory_bench/validation.py | 59 +++++++++++++++---- tests/test_hybrid_fusion.py | 14 ++++- tests/test_memu_adapter.py | 2 +- tests/test_phase_ab_compare_runner.py | 25 ++++++-- tests/test_qmd_adapter.py | 12 +++- 22 files changed, 365 insertions(+), 114 deletions(-) diff --git a/scripts/hybrid_two_stage_from_reports.py b/scripts/hybrid_two_stage_from_reports.py index 12e5e5c..c52febd 100644 --- a/scripts/hybrid_two_stage_from_reports.py +++ b/scripts/hybrid_two_stage_from_reports.py @@ -8,7 +8,9 @@ def main() -> int: - ap = argparse.ArgumentParser(description="Build a two-stage hybrid retrieval report from two reports.") + ap = argparse.ArgumentParser( + description="Build a two-stage hybrid retrieval report from two reports." 
+ ) ap.add_argument("--must-report", required=True) ap.add_argument("--fallback-report", required=True) ap.add_argument("--run-id", required=True) @@ -33,7 +35,9 @@ def main() -> int: must_report = load_report(must_path) fallback_report = load_report(fallback_path) - stage2_max_ms = float(args.stage2_max_ms) if args.stage2_max_ms and args.stage2_max_ms > 0 else None + stage2_max_ms = ( + float(args.stage2_max_ms) if args.stage2_max_ms and args.stage2_max_ms > 0 else None + ) manifest = { "experiment": { @@ -62,7 +66,9 @@ def main() -> int: ) report_path = out_dir / "retrieval-report.json" - report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + report_path.write_text( + json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) md_lines = [ f"# Two-stage hybrid report ({args.run_id})", diff --git a/scripts/run_lancedb_vs_openclaw_mem_assisted.py b/scripts/run_lancedb_vs_openclaw_mem_assisted.py index 87396ea..ea59537 100644 --- a/scripts/run_lancedb_vs_openclaw_mem_assisted.py +++ b/scripts/run_lancedb_vs_openclaw_mem_assisted.py @@ -62,7 +62,9 @@ def _session_importance_label(session: dict[str, Any]) -> str: return "ignore" # Fallback lexical proxy when dataset lacks labels. 
- merged = "\n".join(str(m.get("content") or "") for m in session.get("messages", []) if isinstance(m, dict)).lower() + merged = "\n".join( + str(m.get("content") or "") for m in session.get("messages", []) if isinstance(m, dict) + ).lower() must_kw = ( "must remember", "important", @@ -370,7 +372,9 @@ def _metric_pack(report: dict[str, Any]) -> dict[str, float]: } -def _win_eval(*, baseline: dict[str, float], candidate: dict[str, float], policy: str) -> dict[str, Any]: +def _win_eval( + *, baseline: dict[str, float], candidate: dict[str, float], policy: str +) -> dict[str, Any]: p95_gain = ( (baseline["search_ms_p95"] - candidate["search_ms_p95"]) / baseline["search_ms_p95"] if baseline["search_ms_p95"] @@ -493,7 +497,9 @@ def main() -> int: args = ap.parse_args() repo_root = Path(__file__).resolve().parents[1] - dataset_path = repo_root / args.dataset if not Path(args.dataset).is_absolute() else Path(args.dataset) + dataset_path = ( + repo_root / args.dataset if not Path(args.dataset).is_absolute() else Path(args.dataset) + ) out_root = repo_root / args.output_root run_group = _resolve_run_group(explicit_run_group=args.run_group, run_label=args.run_label) @@ -524,7 +530,9 @@ def main() -> int: if args.include_observational: obs_dataset, obs_stats = _compress_dataset_observational(raw) obs_path = run_dir / "derived-dataset-observational.json" - obs_path.write_text(json.dumps(obs_dataset, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + obs_path.write_text( + json.dumps(obs_dataset, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) obs_report = _run_lancedb( dataset_path=obs_path, @@ -561,7 +569,9 @@ def main() -> int: for policy in args.policies: filtered, filter_stats = _filter_dataset(raw, policy=policy) filtered_path = run_dir / f"derived-dataset-{policy}.json" - filtered_path.write_text(json.dumps(filtered, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + filtered_path.write_text( + json.dumps(filtered, ensure_ascii=False, indent=2) 
+ "\n", encoding="utf-8" + ) report = _run_lancedb( dataset_path=filtered_path, @@ -624,7 +634,11 @@ def main() -> int: "fallback_report_path": fallback_candidate["report"]["report_path"], "mode": "must_count_gate", "fusion_mode": args.hybrid_fusion_mode, - "k_rrf": (float(args.hybrid_k_rrf) if args.hybrid_fusion_mode == "rrf_fusion" else None), + "k_rrf": ( + float(args.hybrid_k_rrf) + if args.hybrid_fusion_mode == "rrf_fusion" + else None + ), "min_must_count": int(args.hybrid_min_must_count), "stage2_max_additional": int(args.hybrid_stage2_max_additional), "stage2_max_ms": stage2_max_ms, @@ -640,7 +654,9 @@ def main() -> int: hybrid_dir = run_dir / "hybrid" hybrid_dir.mkdir(parents=True, exist_ok=True) hybrid_json = hybrid_dir / "retrieval-report.json" - hybrid_json.write_text(json.dumps(hybrid_report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + hybrid_json.write_text( + json.dumps(hybrid_report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) hybrid_md = hybrid_dir / "retrieval-report.md" _write_hybrid_markdown( path=hybrid_md, @@ -702,10 +718,16 @@ def main() -> int: # Pass if p95 improves >=20% while recall drop <=3pp and nDCG non-negative. 
wins: list[dict[str, Any]] = [] for row in curve: - wins.append(_win_eval(baseline=row["baseline"], candidate=row["experimental"], policy=row["policy"])) + wins.append( + _win_eval(baseline=row["baseline"], candidate=row["experimental"], policy=row["policy"]) + ) if hybrid_tradeoff is not None: - wins.append(_win_eval(baseline=baseline_metrics, candidate=hybrid_tradeoff["hybrid"], policy="hybrid")) + wins.append( + _win_eval( + baseline=baseline_metrics, candidate=hybrid_tradeoff["hybrid"], policy="hybrid" + ) + ) arms: dict[str, Any] = { "baseline": baseline, @@ -750,7 +772,9 @@ def main() -> int: } compare_json = run_dir / f"compare-{run_group}.json" - compare_json.write_text(json.dumps(compare, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + compare_json.write_text( + json.dumps(compare, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) lines = [ f"# Phase A/B compare ({run_group})", @@ -859,7 +883,9 @@ def main() -> int: "compare_md": str(compare_md), "latest_pointer": str(latest_pointer), "baseline_report": baseline["report_path"], - "observational_report": (observational["report"]["report_path"] if observational is not None else None), + "observational_report": ( + observational["report"]["report_path"] if observational is not None else None + ), "experimental_reports": [x["report"]["report_path"] for x in candidates], "hybrid_report": (hybrid["report"]["report_path"] if hybrid is not None else None), }, diff --git a/scripts/run_longmemeval50_qa_compare.py b/scripts/run_longmemeval50_qa_compare.py index b378564..a77c8f4 100755 --- a/scripts/run_longmemeval50_qa_compare.py +++ b/scripts/run_longmemeval50_qa_compare.py @@ -96,12 +96,7 @@ def openai_chat_completions( with urllib.request.urlopen(req, timeout=timeout_s) as resp: raw = resp.read().decode("utf-8") data = json.loads(raw) - return ( - data.get("choices", [{}])[0] - .get("message", {}) - .get("content", "") - .strip() - ) + return data.get("choices", [{}])[0].get("message", 
{}).get("content", "").strip() except urllib.error.HTTPError as e: # Read body for debugging (keep local) try: @@ -275,13 +270,20 @@ def main() -> int: ap.add_argument("--judge-model", default="") ap.add_argument("--limit", type=int, default=20, help="question limit (default 20 for Phase A)") ap.add_argument("--seed", type=int, default=7) - ap.add_argument("--arms", nargs="+", default=["oracle", "observational"], choices=["oracle", "full", "observational"]) + ap.add_argument( + "--arms", + nargs="+", + default=["oracle", "observational"], + choices=["oracle", "full", "observational"], + ) ap.add_argument("--max-msg-chars", type=int, default=600) args = ap.parse_args() api_key = os.getenv("OPENAI_API_KEY") or "" if not api_key.strip(): - raise SystemExit("OPENAI_API_KEY is missing/empty. Set it in the environment before running.") + raise SystemExit( + "OPENAI_API_KEY is missing/empty. Set it in the environment before running." + ) judge_model = args.judge_model.strip() or args.model @@ -311,7 +313,9 @@ def main() -> int: "created_at": datetime.now(UTC).isoformat(), "note": "Phase A QA compare on repo-local longmemeval-50 format (not official LongMemEval runner).", } - (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8") + (out_dir / "manifest.json").write_text( + json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8" + ) summary: dict[str, Any] = {"manifest": manifest, "arms": {}} @@ -323,7 +327,10 @@ def main() -> int: rows: list[Row] = [] - with hyp_path.open("w", encoding="utf-8") as hyp_f, eval_path.open("w", encoding="utf-8") as eval_f: + with ( + hyp_path.open("w", encoding="utf-8") as hyp_f, + eval_path.open("w", encoding="utf-8") as eval_f, + ): for i, q in enumerate(qs): qid = str(q.get("question_id") or "") qtype = str(q.get("question_type") or "") @@ -333,7 +340,9 @@ def main() -> int: rel_ids = set(str(x) for x in (q.get("relevant_session_ids") or []) if str(x)) if arm == 
"oracle": - arm_sessions = [s for s in sessions if str(s.get("session_id") or "") in rel_ids] + arm_sessions = [ + s for s in sessions if str(s.get("session_id") or "") in rel_ids + ] else: arm_sessions = sessions @@ -352,11 +361,16 @@ def main() -> int: max_tokens=256, ) - print(json.dumps({"question_id": qid, "hypothesis": hyp}, ensure_ascii=False), file=hyp_f) + print( + json.dumps({"question_id": qid, "hypothesis": hyp}, ensure_ascii=False), + file=hyp_f, + ) # Judge abstention = "_abs" in qid - judge_prompt = get_anscheck_prompt(qtype, question, answer, hyp, abstention=abstention) + judge_prompt = get_anscheck_prompt( + qtype, question, answer, hyp, abstention=abstention + ) _jitter_sleep() judge_resp = openai_chat_completions( @@ -384,7 +398,7 @@ def main() -> int: rows.append(Row(qid, qtype, question, answer, hyp, bool(label))) # Progress line - print(f"[{arm}] {i+1}/{len(qs)} qid={qid} label={label}") + print(f"[{arm}] {i + 1}/{len(qs)} qid={qid} label={label}") # Summarize by_type: dict[str, list[bool]] = {} @@ -398,13 +412,20 @@ def acc(xs: Iterable[bool]) -> float: arm_summary = { "n": len(rows), "accuracy": acc([r.label for r in rows]), - "by_question_type": {k: {"accuracy": acc(v), "n": len(v)} for k, v in sorted(by_type.items())}, - "paths": {"hypotheses": str(hyp_path.relative_to(REPO_ROOT)), "eval": str(eval_path.relative_to(REPO_ROOT))}, + "by_question_type": { + k: {"accuracy": acc(v), "n": len(v)} for k, v in sorted(by_type.items()) + }, + "paths": { + "hypotheses": str(hyp_path.relative_to(REPO_ROOT)), + "eval": str(eval_path.relative_to(REPO_ROOT)), + }, } summary["arms"][arm] = arm_summary # Write summary - (out_dir / "summary.json").write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8") + (out_dir / "summary.json").write_text( + json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8" + ) # Markdown md_lines: list[str] = [] diff --git a/scripts/run_longmemeval50_qa_compare_openclaw.py 
b/scripts/run_longmemeval50_qa_compare_openclaw.py index 574581a..24e8793 100755 --- a/scripts/run_longmemeval50_qa_compare_openclaw.py +++ b/scripts/run_longmemeval50_qa_compare_openclaw.py @@ -77,7 +77,11 @@ def openclaw_agent_once( ] proc = subprocess.run(cmd, capture_output=True, text=True) if proc.returncode != 0: - raise RuntimeError(proc.stderr.strip() or proc.stdout.strip() or f"openclaw agent failed ({proc.returncode})") + raise RuntimeError( + proc.stderr.strip() + or proc.stdout.strip() + or f"openclaw agent failed ({proc.returncode})" + ) data = json.loads(proc.stdout) payloads = (((data or {}).get("result") or {}).get("payloads")) or [] @@ -89,7 +93,9 @@ def openclaw_agent_once( return "\n".join(texts).strip() -def get_anscheck_prompt(task: str, question: str, answer: str, response: str, *, abstention: bool = False) -> str: +def get_anscheck_prompt( + task: str, question: str, answer: str, response: str, *, abstention: bool = False +) -> str: # Adapted from LongMemEval src/evaluation/evaluate_qa.py if not abstention: if task in {"single-session-user", "single-session-assistant", "multi-session"}: @@ -219,7 +225,12 @@ def main() -> int: ap.add_argument("--run-group", default=f"{_now_tag()}-longmemeval50-qa-openclaw") ap.add_argument("--limit", type=int, default=20) ap.add_argument("--seed", type=int, default=7) - ap.add_argument("--arms", nargs="+", default=["oracle", "observational"], choices=["oracle", "full", "observational"]) + ap.add_argument( + "--arms", + nargs="+", + default=["oracle", "observational"], + choices=["oracle", "full", "observational"], + ) ap.add_argument("--max-msg-chars", type=int, default=600) ap.add_argument("--thinking", default="high") args = ap.parse_args() @@ -249,7 +260,9 @@ def main() -> int: "created_at": datetime.now(UTC).isoformat(), "note": "Uses openclaw agent (Gateway) for actor+judge; pacing enforced via sleeps.", } - (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", 
encoding="utf-8") + (out_dir / "manifest.json").write_text( + json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8" + ) summary: dict[str, Any] = {"manifest": manifest, "arms": {}} @@ -261,7 +274,10 @@ def main() -> int: rows: list[Row] = [] - with hyp_path.open("w", encoding="utf-8") as hyp_f, eval_path.open("w", encoding="utf-8") as eval_f: + with ( + hyp_path.open("w", encoding="utf-8") as hyp_f, + eval_path.open("w", encoding="utf-8") as eval_f, + ): for i, q in enumerate(qs): qid = str(q.get("question_id") or "") qtype = str(q.get("question_type") or "") @@ -271,7 +287,9 @@ def main() -> int: rel_ids = set(str(x) for x in (q.get("relevant_session_ids") or []) if str(x)) if arm == "oracle": - arm_sessions = [s for s in sessions if str(s.get("session_id") or "") in rel_ids] + arm_sessions = [ + s for s in sessions if str(s.get("session_id") or "") in rel_ids + ] else: arm_sessions = sessions @@ -290,12 +308,18 @@ def main() -> int: message=actor_message(history=history, question=question), thinking=args.thinking, ) - print(json.dumps({"question_id": qid, "hypothesis": hyp}, ensure_ascii=False), file=hyp_f, flush=True) + print( + json.dumps({"question_id": qid, "hypothesis": hyp}, ensure_ascii=False), + file=hyp_f, + flush=True, + ) hyp_f.flush() # Judge abstention = "_abs" in qid - judge_prompt = get_anscheck_prompt(qtype, question, answer, hyp, abstention=abstention) + judge_prompt = get_anscheck_prompt( + qtype, question, answer, hyp, abstention=abstention + ) _sleep_jitter(rng) judge_resp = openclaw_agent_once( @@ -318,7 +342,7 @@ def main() -> int: eval_f.flush() rows.append(Row(qid, qtype, bool(label))) - print(f"[{arm}] {i+1}/{len(qs)} qid={qid} label={label}", flush=True) + print(f"[{arm}] {i + 1}/{len(qs)} qid={qid} label={label}", flush=True) by_type: dict[str, list[bool]] = {} for r in rows: @@ -331,12 +355,19 @@ def acc(xs: Iterable[bool]) -> float: arm_summary = { "n": len(rows), "accuracy": acc([r.label for r in rows]), - 
"by_question_type": {k: {"accuracy": acc(v), "n": len(v)} for k, v in sorted(by_type.items())}, - "paths": {"hypotheses": str(hyp_path.relative_to(REPO_ROOT)), "eval": str(eval_path.relative_to(REPO_ROOT))}, + "by_question_type": { + k: {"accuracy": acc(v), "n": len(v)} for k, v in sorted(by_type.items()) + }, + "paths": { + "hypotheses": str(hyp_path.relative_to(REPO_ROOT)), + "eval": str(eval_path.relative_to(REPO_ROOT)), + }, } summary["arms"][arm] = arm_summary - (out_dir / "summary.json").write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8") + (out_dir / "summary.json").write_text( + json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8" + ) md_lines: list[str] = [] md_lines.append(f"# LongMemEval-50 QA compare (Phase A, OpenClaw) — {run_group}\n") diff --git a/scripts/run_memory_core_vs_openclaw_mem.py b/scripts/run_memory_core_vs_openclaw_mem.py index 12ddec4..2475557 100644 --- a/scripts/run_memory_core_vs_openclaw_mem.py +++ b/scripts/run_memory_core_vs_openclaw_mem.py @@ -2,6 +2,7 @@ import argparse import json +import os import re from datetime import UTC, datetime from pathlib import Path @@ -98,7 +99,9 @@ def main() -> int: args = ap.parse_args() repo_root = Path(__file__).resolve().parents[1] - dataset_path = repo_root / args.dataset if not Path(args.dataset).is_absolute() else Path(args.dataset) + dataset_path = ( + repo_root / args.dataset if not Path(args.dataset).is_absolute() else Path(args.dataset) + ) out_root = repo_root / args.output_root run_group = f"{_now_tag()}-{_slug(args.run_label)}" @@ -161,7 +164,9 @@ def main() -> int: } compare_json = run_dir / f"compare-{run_group}.json" - compare_json.write_text(json.dumps(compare, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + compare_json.write_text( + json.dumps(compare, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) lines = [ f"# Sidecar compare ({run_group})", diff --git a/scripts/run_memory_triplet_comprehensive.py 
b/scripts/run_memory_triplet_comprehensive.py index 43bacdf..c0cfe7e 100644 --- a/scripts/run_memory_triplet_comprehensive.py +++ b/scripts/run_memory_triplet_comprehensive.py @@ -41,7 +41,9 @@ def _log(msg: str, progress_log: Path | None) -> None: fh.write(line + "\n") -def _prepare_dataset(*, benchmark: str, limit: int | None, out: Path) -> tuple[Path, dict[str, Any]]: +def _prepare_dataset( + *, benchmark: str, limit: int | None, out: Path +) -> tuple[Path, dict[str, Any]]: data = convert_benchmark(benchmark, limit=limit) validate_dataset_payload(data) @@ -251,8 +253,12 @@ def _metric_pack(report: dict[str, Any]) -> dict[str, float]: def main() -> int: - ap = argparse.ArgumentParser(description="Comprehensive triplet benchmark: memory-core, memory-lancedb, openclaw-mem") - ap.add_argument("--benchmark", default="longmemeval", choices=["locomo", "longmemeval", "convomem"]) + ap = argparse.ArgumentParser( + description="Comprehensive triplet benchmark: memory-core, memory-lancedb, openclaw-mem" + ) + ap.add_argument( + "--benchmark", default="longmemeval", choices=["locomo", "longmemeval", "convomem"] + ) ap.add_argument("--dataset-limit", type=int, default=100) ap.add_argument("--question-limit", type=int, default=100) ap.add_argument("--top-k", type=int, default=10) @@ -294,7 +300,11 @@ def main() -> int: ) dataset = load_retrieval_dataset(dataset_path) - effective_questions = min(args.question_limit, len(dataset.questions)) if args.question_limit is not None else len(dataset.questions) + effective_questions = ( + min(args.question_limit, len(dataset.questions)) + if args.question_limit is not None + else len(dataset.questions) + ) (run_dir / "dataset.meta.json").write_text( json.dumps(dataset_meta, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" @@ -420,10 +430,12 @@ def main() -> int: "progress_log": str(progress_log), "metrics": metrics, "delta_openclaw_mem_minus_memory_core": { - k: metrics["openclaw-mem"][k] - metrics["memory-core"][k] for k in 
metrics["memory-core"] + k: metrics["openclaw-mem"][k] - metrics["memory-core"][k] + for k in metrics["memory-core"] }, "delta_lancedb_minus_memory_core": { - k: metrics["memory-lancedb"][k] - metrics["memory-core"][k] for k in metrics["memory-core"] + k: metrics["memory-lancedb"][k] - metrics["memory-core"][k] + for k in metrics["memory-core"] }, "delta_openclaw_mem_minus_lancedb": { k: metrics["openclaw-mem"][k] - metrics["memory-lancedb"][k] @@ -432,7 +444,9 @@ def main() -> int: } compare_json = run_dir / f"compare-{run_group}.json" - compare_json.write_text(json.dumps(compare, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + compare_json.write_text( + json.dumps(compare, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) lines = [ f"# Comprehensive triplet report ({run_group})", diff --git a/scripts/run_two_plugin_baseline.py b/scripts/run_two_plugin_baseline.py index 0189262..ba97c3c 100644 --- a/scripts/run_two_plugin_baseline.py +++ b/scripts/run_two_plugin_baseline.py @@ -30,7 +30,9 @@ def _read_profile(path: Path) -> dict[str, Any]: return data -def _prepare_dataset(profile: dict[str, Any], dataset_limit_override: int | None) -> tuple[Path, dict[str, Any]]: +def _prepare_dataset( + profile: dict[str, Any], dataset_limit_override: int | None +) -> tuple[Path, dict[str, Any]]: bench = str(profile["benchmark"]) ds = profile.get("dataset") or {} @@ -155,12 +157,15 @@ def _metrics(rep: dict[str, Any]) -> dict[str, float]: "delta_openclaw_minus_memu": delta, "comparability": { "same_top_k": openclaw_report["top_k"] == memu_report["top_k"], - "same_question_count": openclaw_report["summary"]["questions_total"] == memu_report["summary"]["questions_total"], + "same_question_count": openclaw_report["summary"]["questions_total"] + == memu_report["summary"]["questions_total"], }, } compare_json = run_dir / f"compare-{run_group}.json" - compare_json.write_text(json.dumps(compare, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + 
compare_json.write_text( + json.dumps(compare, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) lines = [ f"# Two-plugin compare report ({run_group})", @@ -199,20 +204,35 @@ def _metrics(rep: dict[str, Any]) -> dict[str, float]: def main() -> int: - ap = argparse.ArgumentParser(description="Run two-plugin retrieval baseline and generate compare artifacts") + ap = argparse.ArgumentParser( + description="Run two-plugin retrieval baseline and generate compare artifacts" + ) ap.add_argument("--profile", default="configs/run-profiles/two-plugin-baseline.json") - ap.add_argument("--dataset-limit", type=int, default=None, help="override dataset conversion limit") - ap.add_argument("--question-limit", type=int, default=None, help="override benchmark question count used in run") + ap.add_argument( + "--dataset-limit", type=int, default=None, help="override dataset conversion limit" + ) + ap.add_argument( + "--question-limit", + type=int, + default=None, + help="override benchmark question count used in run", + ) ap.add_argument("--run-label", default=None, help="optional suffix for run group") ap.add_argument("--gateway-url", default=None, help="override memu gateway url") ap.add_argument("--gateway-token", default=None, help="override memu gateway token") args = ap.parse_args() repo_root = Path(__file__).resolve().parents[1] - profile_path = (repo_root / args.profile).resolve() if not Path(args.profile).is_absolute() else Path(args.profile) + profile_path = ( + (repo_root / args.profile).resolve() + if not Path(args.profile).is_absolute() + else Path(args.profile) + ) profile = _read_profile(profile_path) - out_root = repo_root / str((profile.get("output") or {}).get("root") or "artifacts/full-benchmark") + out_root = repo_root / str( + (profile.get("output") or {}).get("root") or "artifacts/full-benchmark" + ) run_group = _now_tag() + ("-" + _slug(args.run_label) if args.run_label else "") run_dir = out_root / run_group run_dir.mkdir(parents=True, exist_ok=True) 
@@ -222,7 +242,9 @@ def main() -> int: retrieval = profile.get("retrieval") or {} top_k = int(retrieval.get("top_k") or 10) - question_limit = args.question_limit if args.question_limit is not None else retrieval.get("question_limit") + question_limit = ( + args.question_limit if args.question_limit is not None else retrieval.get("question_limit") + ) question_limit = int(question_limit) if question_limit is not None else None (run_dir / "profile.lock.json").write_text( @@ -298,7 +320,9 @@ def main() -> int: first_err = None if memu_report["failures"]: first_err = memu_report["failures"][0].get("error") - raise RuntimeError(first_err or "memu preferred mode failed without successful questions") + raise RuntimeError( + first_err or "memu preferred mode failed without successful questions" + ) except Exception as e: # noqa: BLE001 if not allow_fallback: raise diff --git a/scripts/summarize_phase_ab_multiseed.py b/scripts/summarize_phase_ab_multiseed.py index 1208a3c..6d6c2b7 100644 --- a/scripts/summarize_phase_ab_multiseed.py +++ b/scripts/summarize_phase_ab_multiseed.py @@ -177,8 +177,10 @@ def _extract_policy_block(compare: dict[str, Any], policy: str) -> dict[str, Any "hit_at_k": float(esum["hit_at_k"]) - float(bsum["hit_at_k"]), "mrr": float(esum["mrr"]) - float(bsum["mrr"]), "ndcg_at_k": float(esum["ndcg_at_k"]) - float(bsum["ndcg_at_k"]), - "search_ms_p50": float(esum.get("search_ms_p50", 0.0)) - float(bsum.get("search_ms_p50", 0.0)), - "search_ms_p95": float(esum.get("search_ms_p95", 0.0)) - float(bsum.get("search_ms_p95", 0.0)), + "search_ms_p50": float(esum.get("search_ms_p50", 0.0)) + - float(bsum.get("search_ms_p50", 0.0)), + "search_ms_p95": float(esum.get("search_ms_p95", 0.0)) + - float(bsum.get("search_ms_p95", 0.0)), }, } @@ -278,7 +280,9 @@ def main() -> int: "search_ms_p95", ): vals = [r.policies[policy]["delta"][key] for r in rows] - block["delta"][key] = bootstrap_mean_ci(vals, n=args.bootstrap_n, seed=args.bootstrap_seed) + block["delta"][key] = 
bootstrap_mean_ci( + vals, n=args.bootstrap_n, seed=args.bootstrap_seed + ) out["policies"][policy] = block @@ -310,13 +314,23 @@ def fmt_ci(d: dict[str, float]) -> str: lines.append(f"- latency p95(ms): {fmt_ci(out['baseline']['search_ms_p95'])}") for policy in ("must", "must+nice"): - lines.append(f"\n## Experimental policy = {policy} (Δ experimental - baseline; mean [95% CI] over seeds)") - lines.append(f"- compression items: {fmt_ci(out['policies'][policy]['compression_ratio_items'])}") - lines.append(f"- compression chars: {fmt_ci(out['policies'][policy]['compression_ratio_chars'])}") + lines.append( + f"\n## Experimental policy = {policy} (Δ experimental - baseline; mean [95% CI] over seeds)" + ) + lines.append( + f"- compression items: {fmt_ci(out['policies'][policy]['compression_ratio_items'])}" + ) + lines.append( + f"- compression chars: {fmt_ci(out['policies'][policy]['compression_ratio_chars'])}" + ) for key in ("hit_at_k", "precision_at_k", "recall_at_k", "mrr", "ndcg_at_k"): lines.append(f"- Δ {key}: {fmt_ci(out['policies'][policy]['delta'][key])}") - lines.append(f"- Δ latency p50(ms): {fmt_ci(out['policies'][policy]['delta']['search_ms_p50'])}") - lines.append(f"- Δ latency p95(ms): {fmt_ci(out['policies'][policy]['delta']['search_ms_p95'])}") + lines.append( + f"- Δ latency p50(ms): {fmt_ci(out['policies'][policy]['delta']['search_ms_p50'])}" + ) + lines.append( + f"- Δ latency p95(ms): {fmt_ci(out['policies'][policy]['delta']['search_ms_p95'])}" + ) lines.append("\n## Notes") for n in out["notes"]: diff --git a/src/openclaw_memory_bench/adapters/memory_core.py b/src/openclaw_memory_bench/adapters/memory_core.py index af98a77..c8bf583 100644 --- a/src/openclaw_memory_bench/adapters/memory_core.py +++ b/src/openclaw_memory_bench/adapters/memory_core.py @@ -281,7 +281,11 @@ def await_indexing(self, ingest_result: dict, container_tag: str) -> None: def _session_id_from_row(path: str | None, snippet: str | None) -> str | None: if path: name = 
Path(path).name - if name.startswith("session-") and name.endswith(".md") and "container_tag:" not in (snippet or ""): + if ( + name.startswith("session-") + and name.endswith(".md") + and "container_tag:" not in (snippet or "") + ): # Legacy filename format: session-.md return name[len("session-") : -len(".md")] diff --git a/src/openclaw_memory_bench/adapters/memory_lancedb.py b/src/openclaw_memory_bench/adapters/memory_lancedb.py index fba3002..dd7c548 100644 --- a/src/openclaw_memory_bench/adapters/memory_lancedb.py +++ b/src/openclaw_memory_bench/adapters/memory_lancedb.py @@ -60,7 +60,9 @@ def _extract_memories(result: Any) -> list[dict]: return [] def _invoke(self, tool: str, args: dict) -> Any: - return invoke_tool(tool=tool, tool_args=args, session_key=self.session_key, config=self.config) + return invoke_tool( + tool=tool, tool_args=args, session_key=self.session_key, config=self.config + ) def clear(self, container_tag: str) -> None: ids = list(self._container_ids.get(container_tag, [])) diff --git a/src/openclaw_memory_bench/adapters/openclaw_mem.py b/src/openclaw_memory_bench/adapters/openclaw_mem.py index 792653e..6baad74 100644 --- a/src/openclaw_memory_bench/adapters/openclaw_mem.py +++ b/src/openclaw_memory_bench/adapters/openclaw_mem.py @@ -138,7 +138,11 @@ def ingest(self, sessions: list[Session], container_tag: str) -> dict: ) out = self._run(cmd) result = json.loads(out) - return {"document_ids": result.get("ids", []), "container_tag": container_tag, "db_path": db_path} + return { + "document_ids": result.get("ids", []), + "container_tag": container_tag, + "db_path": db_path, + } def await_indexing(self, ingest_result: dict, container_tag: str) -> None: # local SQLite/FTS path is effectively immediate for now diff --git a/src/openclaw_memory_bench/cli.py b/src/openclaw_memory_bench/cli.py index 3e2f533..a611be7 100644 --- a/src/openclaw_memory_bench/cli.py +++ b/src/openclaw_memory_bench/cli.py @@ -48,7 +48,11 @@ def cmd_plan(args: 
argparse.Namespace) -> int: out.parent.mkdir(parents=True, exist_ok=True) out.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") - print(json.dumps({"ok": True, "manifest": str(out), "run_id": run_id}, ensure_ascii=False, indent=2)) + print( + json.dumps( + {"ok": True, "manifest": str(out), "run_id": run_id}, ensure_ascii=False, indent=2 + ) + ) return 0 @@ -197,7 +201,9 @@ def build_parser() -> argparse.ArgumentParser: plan.add_argument("--out", default="artifacts/run-manifest.json") plan.set_defaults(func=cmd_plan) - prep = sub.add_parser("prepare-dataset", help="Download and convert canonical benchmark dataset") + prep = sub.add_parser( + "prepare-dataset", help="Download and convert canonical benchmark dataset" + ) prep.add_argument("--benchmark", required=True, choices=["locomo", "longmemeval", "convomem"]) prep.add_argument("--limit", type=int, default=None, help="Limit number of converted questions") prep.add_argument("--out", required=True, help="Output retrieval dataset JSON path") @@ -246,7 +252,9 @@ def build_parser() -> argparse.ArgumentParser: help="(openclaw-mem) explicit command base override, e.g. 
openclaw-mem", ) run.add_argument("--out", default=None, help="Output report path") - run.add_argument("--skip-ingest", action="store_true", help="Skip adapter ingest and search existing memory") + run.add_argument( + "--skip-ingest", action="store_true", help="Skip adapter ingest and search existing memory" + ) run.add_argument( "--preindex-once", action="store_true", @@ -304,8 +312,12 @@ def build_parser() -> argparse.ArgumentParser: ) # memu-engine / gateway options - run.add_argument("--gateway-url", default=None, help="Gateway base URL (default from local config)") - run.add_argument("--gateway-token", default=None, help="Gateway token (default from env/config)") + run.add_argument( + "--gateway-url", default=None, help="Gateway base URL (default from local config)" + ) + run.add_argument( + "--gateway-token", default=None, help="Gateway token (default from env/config)" + ) run.add_argument("--agent-id", default="main", help="x-openclaw-agent-id header") run.add_argument("--session-key", default="main", help="sessionKey for tools/invoke") run.add_argument( diff --git a/src/openclaw_memory_bench/converters.py b/src/openclaw_memory_bench/converters.py index dcd785c..154aa8f 100644 --- a/src/openclaw_memory_bench/converters.py +++ b/src/openclaw_memory_bench/converters.py @@ -176,7 +176,9 @@ def convert_convomem(*, limit: int | None = None) -> dict: sid = f"{qid}-session-{ci}" msgs = [ _message( - "user" if str(m.get("speaker", "")).lower() == "user" else "assistant", + "user" + if str(m.get("speaker", "")).lower() == "user" + else "assistant", str(m.get("text") or ""), ) for m in (conv.get("messages") or []) @@ -184,10 +186,14 @@ def convert_convomem(*, limit: int | None = None) -> dict: sessions.append({"session_id": sid, "messages": msgs, "metadata": {}}) evidence_texts = { - str(x.get("text") or "") for x in (ev.get("message_evidences") or []) if x.get("text") + str(x.get("text") or "") + for x in (ev.get("message_evidences") or []) + if x.get("text") } 
relevant = [ - s["session_id"] for s in sessions if _session_has_evidence_messages(s, evidence_texts) + s["session_id"] + for s in sessions + if _session_has_evidence_messages(s, evidence_texts) ] if not relevant and sessions: relevant = [sessions[0]["session_id"]] diff --git a/src/openclaw_memory_bench/gateway_client.py b/src/openclaw_memory_bench/gateway_client.py index 8b3fafb..67cf2f7 100644 --- a/src/openclaw_memory_bench/gateway_client.py +++ b/src/openclaw_memory_bench/gateway_client.py @@ -48,11 +48,15 @@ def resolve_gateway_config(overrides: dict | None = None) -> dict[str, str]: } -def invoke_tool(*, tool: str, tool_args: dict, session_key: str = "main", config: dict | None = None) -> Any: +def invoke_tool( + *, tool: str, tool_args: dict, session_key: str = "main", config: dict | None = None +) -> Any: resolved = resolve_gateway_config(config) token = resolved["gateway_token"] if not token: - raise RuntimeError("Gateway token is required (OPENCLAW_GATEWAY_TOKEN or ~/.openclaw/openclaw.json)") + raise RuntimeError( + "Gateway token is required (OPENCLAW_GATEWAY_TOKEN or ~/.openclaw/openclaw.json)" + ) url = resolved["gateway_url"] + "/tools/invoke" payload = { diff --git a/src/openclaw_memory_bench/hybrid.py b/src/openclaw_memory_bench/hybrid.py index 0cb0ecd..a8c7a45 100644 --- a/src/openclaw_memory_bench/hybrid.py +++ b/src/openclaw_memory_bench/hybrid.py @@ -192,9 +192,15 @@ def build_two_stage_hybrid_report( if isinstance(rel_src, dict) and isinstance(rel_src.get("relevant_session_ids"), list): rel_ids = [str(x) for x in rel_src.get("relevant_session_ids") if str(x)] - must_ids = _coerce_str_list(must_row.get("retrieved_session_ids") if isinstance(must_row, dict) else None) - must_scores = _coerce_float_list(must_row.get("retrieved_scores") if isinstance(must_row, dict) else None) - must_latency = float(must_row.get("latency_ms") or 0.0) if isinstance(must_row, dict) else 0.0 + must_ids = _coerce_str_list( + must_row.get("retrieved_session_ids") if 
isinstance(must_row, dict) else None + ) + must_scores = _coerce_float_list( + must_row.get("retrieved_scores") if isinstance(must_row, dict) else None + ) + must_latency = ( + float(must_row.get("latency_ms") or 0.0) if isinstance(must_row, dict) else 0.0 + ) fallback_ids = _coerce_str_list( fallback_row.get("retrieved_session_ids") if isinstance(fallback_row, dict) else None @@ -202,7 +208,9 @@ def build_two_stage_hybrid_report( fallback_scores = _coerce_float_list( fallback_row.get("retrieved_scores") if isinstance(fallback_row, dict) else None ) - fallback_latency = float(fallback_row.get("latency_ms") or 0.0) if isinstance(fallback_row, dict) else 0.0 + fallback_latency = ( + float(fallback_row.get("latency_ms") or 0.0) if isinstance(fallback_row, dict) else 0.0 + ) if len(must_scores) < len(must_ids): must_scores = must_scores + [0.0] * (len(must_ids) - len(must_scores)) diff --git a/src/openclaw_memory_bench/metrics.py b/src/openclaw_memory_bench/metrics.py index dbbf798..ac8b3af 100644 --- a/src/openclaw_memory_bench/metrics.py +++ b/src/openclaw_memory_bench/metrics.py @@ -32,7 +32,9 @@ def score_retrieval(retrieved_ids: list[str], relevant_ids: list[str], k: int) - relevant = set(relevant_ids) if not relevant: - return RetrievalMetrics(hit_at_k=0.0, precision_at_k=0.0, recall_at_k=0.0, mrr=0.0, ndcg_at_k=0.0) + return RetrievalMetrics( + hit_at_k=0.0, precision_at_k=0.0, recall_at_k=0.0, mrr=0.0, ndcg_at_k=0.0 + ) binary = [1 if x in relevant else 0 for x in ranked] rel_count = sum(binary) diff --git a/src/openclaw_memory_bench/runner.py b/src/openclaw_memory_bench/runner.py index c4f1970..27838be 100644 --- a/src/openclaw_memory_bench/runner.py +++ b/src/openclaw_memory_bench/runner.py @@ -259,7 +259,9 @@ def run_retrieval_benchmark( "relevant_session_ids": q.relevant_session_ids, "retrieved_session_ids": retrieved_session_ids, "retrieved_observation_ids": [h.id for h in hits], - "retrieved_sources": [h.metadata.get("path") for h in hits if 
h.metadata.get("path")], + "retrieved_sources": [ + h.metadata.get("path") for h in hits if h.metadata.get("path") + ], "ingest_result": ingest_result, "latency_ms": dt_ms, "metrics": asdict(metrics), @@ -291,7 +293,9 @@ def run_retrieval_benchmark( qid_to_type = {q.question_id: q.question_type for q in questions} totals_by_type: Counter[str] = Counter(q.question_type for q in questions) - failed_by_type: Counter[str] = Counter(qid_to_type.get(f.get("question_id"), "unknown") for f in failures) + failed_by_type: Counter[str] = Counter( + qid_to_type.get(f.get("question_id"), "unknown") for f in failures + ) rows_by_type: dict[str, list[dict]] = {} for row in results: diff --git a/src/openclaw_memory_bench/validation.py b/src/openclaw_memory_bench/validation.py index 5661f72..e351b8a 100644 --- a/src/openclaw_memory_bench/validation.py +++ b/src/openclaw_memory_bench/validation.py @@ -27,7 +27,9 @@ def _require(condition: bool, path: str, message: str, errors: list[str]) -> Non def _require_non_empty_str(value: Any, path: str, errors: list[str]) -> None: - _require(isinstance(value, str) and value.strip() != "", path, "must be a non-empty string", errors) + _require( + isinstance(value, str) and value.strip() != "", path, "must be a non-empty string", errors + ) def _require_list(value: Any, path: str, errors: list[str]) -> list[Any]: @@ -128,25 +130,49 @@ def validate_retrieval_report_payload(report: dict[str, Any]) -> None: _require(isinstance(summary, dict), "report.summary", "must be an object", errors) if isinstance(summary, dict): for key in ("questions_total", "questions_succeeded", "questions_failed"): - _require(isinstance(summary.get(key), int), f"report.summary.{key}", "must be an integer", errors) + _require( + isinstance(summary.get(key), int), + f"report.summary.{key}", + "must be an integer", + errors, + ) _validate_metrics(summary, "report.summary", errors) breakdown = summary.get("failure_breakdown") - _require(isinstance(breakdown, dict), 
"report.summary.failure_breakdown", "must be an object", errors) + _require( + isinstance(breakdown, dict), + "report.summary.failure_breakdown", + "must be an object", + errors, + ) if isinstance(breakdown, dict): for key in ("by_code", "by_category", "by_phase"): obj = breakdown.get(key) - _require(isinstance(obj, dict), f"report.summary.failure_breakdown.{key}", "must be an object", errors) + _require( + isinstance(obj, dict), + f"report.summary.failure_breakdown.{key}", + "must be an object", + errors, + ) if isinstance(obj, dict): for kk, vv in obj.items(): - _require_non_empty_str(kk, f"report.summary.failure_breakdown.{key} key", errors) - _require(isinstance(vv, int), f"report.summary.failure_breakdown.{key}.{kk}", "must be an integer", errors) + _require_non_empty_str( + kk, f"report.summary.failure_breakdown.{key} key", errors + ) + _require( + isinstance(vv, int), + f"report.summary.failure_breakdown.{key}.{kk}", + "must be an integer", + errors, + ) latency = report.get("latency") _require(isinstance(latency, dict), "report.latency", "must be an object", errors) if isinstance(latency, dict): for key in ("search_ms_p50", "search_ms_p95", "search_ms_mean"): - _require(_is_number(latency.get(key)), f"report.latency.{key}", "must be a number", errors) + _require( + _is_number(latency.get(key)), f"report.latency.{key}", "must be a number", errors + ) results = _require_list(report.get("results"), "report.results", errors) for ri, row in enumerate(results): @@ -162,10 +188,17 @@ def validate_retrieval_report_payload(report: dict[str, Any]) -> None: ): _require_non_empty_str(row.get(key), f"{rpath}.{key}", errors) - for key in ("relevant_session_ids", "retrieved_session_ids", "retrieved_observation_ids", "retrieved_sources"): + for key in ( + "relevant_session_ids", + "retrieved_session_ids", + "retrieved_observation_ids", + "retrieved_sources", + ): _require(isinstance(row.get(key), list), f"{rpath}.{key}", "must be a list", errors) - 
_require(_is_number(row.get("latency_ms")), f"{rpath}.latency_ms", "must be a number", errors) + _require( + _is_number(row.get("latency_ms")), f"{rpath}.latency_ms", "must be a number", errors + ) metrics = row.get("metrics") _require(isinstance(metrics, dict), f"{rpath}.metrics", "must be an object", errors) @@ -183,7 +216,9 @@ def validate_retrieval_report_payload(report: dict[str, Any]) -> None: _require_non_empty_str(f.get("phase"), f"{fpath}.phase", errors) _require_non_empty_str(f.get("error_code"), f"{fpath}.error_code", errors) _require_non_empty_str(f.get("error_category"), f"{fpath}.error_category", errors) - _require(isinstance(f.get("retryable"), bool), f"{fpath}.retryable", "must be a boolean", errors) + _require( + isinstance(f.get("retryable"), bool), f"{fpath}.retryable", "must be a boolean", errors + ) _require_non_empty_str(f.get("exception_type"), f"{fpath}.exception_type", errors) _require_non_empty_str(f.get("error"), f"{fpath}.error", errors) @@ -191,7 +226,9 @@ def validate_retrieval_report_payload(report: dict[str, Any]) -> None: raise SchemaValidationError(errors) -def validate_required_keys(payload: dict[str, Any], keys: Sequence[str], *, path: str = "object") -> None: +def validate_required_keys( + payload: dict[str, Any], keys: Sequence[str], *, path: str = "object" +) -> None: errors: list[str] = [] for key in keys: _require(key in payload, f"{path}.{key}", "is required", errors) diff --git a/tests/test_hybrid_fusion.py b/tests/test_hybrid_fusion.py index 8665eab..b00e45c 100644 --- a/tests/test_hybrid_fusion.py +++ b/tests/test_hybrid_fusion.py @@ -6,7 +6,9 @@ from openclaw_memory_bench.hybrid import _rrf_merge, build_two_stage_hybrid_report _RRF_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "hybrid" / "rrf_tie_break_case.json" -_GATE_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "hybrid" / "stage2_budget_latency_case.json" +_GATE_FIXTURE_PATH = ( + Path(__file__).parent / "fixtures" / "hybrid" / 
"stage2_budget_latency_case.json" +) def _load_fixture(path: Path) -> dict: @@ -84,7 +86,10 @@ def test_append_fill_respects_stage1_order_before_stage2_fill() -> None: fusion_mode="append_fill", ) - assert report["results"][0]["retrieved_session_ids"] == fixture["expected"]["append_fill_ranked_top4"] + assert ( + report["results"][0]["retrieved_session_ids"] + == fixture["expected"]["append_fill_ranked_top4"] + ) def test_stage2_budget_and_latency_gate_receipts_from_fixture() -> None: @@ -110,6 +115,9 @@ def test_stage2_budget_and_latency_gate_receipts_from_fixture() -> None: assert row["retrieved_session_ids"] == row_expected["retrieved_session_ids"] assert two_stage["stage2_used"] is row_expected["stage2_used"] assert two_stage["stage2_skipped_budget"] is row_expected["stage2_skipped_budget"] - assert two_stage["stage2_candidates_considered"] == row_expected["stage2_candidates_considered"] + assert ( + two_stage["stage2_candidates_considered"] + == row_expected["stage2_candidates_considered"] + ) assert two_stage["stage2_added_count"] == row_expected["stage2_added_count"] assert row["latency_ms"] == row_expected["latency_ms"] diff --git a/tests/test_memu_adapter.py b/tests/test_memu_adapter.py index a00a0cf..02deac0 100644 --- a/tests/test_memu_adapter.py +++ b/tests/test_memu_adapter.py @@ -6,7 +6,7 @@ def test_extract_results_from_content_text_json() -> None: "content": [ { "type": "text", - "text": '{"results":[{"path":"/tmp/sessions/abc123.jsonl","snippet":"hello","score":0.9}]}' + "text": '{"results":[{"path":"/tmp/sessions/abc123.jsonl","snippet":"hello","score":0.9}]}', } ] } diff --git a/tests/test_phase_ab_compare_runner.py b/tests/test_phase_ab_compare_runner.py index e0f1f8a..0131062 100644 --- a/tests/test_phase_ab_compare_runner.py +++ b/tests/test_phase_ab_compare_runner.py @@ -6,7 +6,9 @@ from typing import Any -_HYBRID_GATE_FIXTURE_PATH = Path(__file__).parent / "fixtures" / "hybrid" / "stage2_budget_latency_case.json" +_HYBRID_GATE_FIXTURE_PATH = 
( + Path(__file__).parent / "fixtures" / "hybrid" / "stage2_budget_latency_case.json" +) def _load_runner_module(): @@ -26,7 +28,10 @@ def _load_hybrid_gate_fixture() -> dict[str, Any]: def test_resolve_run_group_slug() -> None: runner = _load_runner_module() - assert runner._resolve_run_group(explicit_run_group="Deterministic Run 01", run_label="ignored") == "deterministic-run-01" + assert ( + runner._resolve_run_group(explicit_run_group="Deterministic Run 01", run_label="ignored") + == "deterministic-run-01" + ) def test_main_writes_stable_latest_pointer(monkeypatch, tmp_path, capsys) -> None: @@ -167,7 +172,9 @@ def test_main_can_build_hybrid_arm(monkeypatch, tmp_path, capsys) -> None: out_root = tmp_path / "out" - def _stub_report(*, run_id: str, top_k: int, retrieved: list[str], scores: list[float], latency_ms: float): + def _stub_report( + *, run_id: str, top_k: int, retrieved: list[str], scores: list[float], latency_ms: float + ): hit = 1.0 if any(x in {"s1", "s2"} for x in retrieved[:top_k]) else 0.0 recall = sum(1 for x in retrieved[:top_k] if x in {"s1", "s2"}) / 2.0 return { @@ -236,7 +243,9 @@ def _fake_run_lancedb(**kwargs): scores=scores, latency_ms=latency_ms, ) - report_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + report_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) return { "label": run_suffix, @@ -286,7 +295,9 @@ def _fake_run_lancedb(**kwargs): assert payload["hybrid_report"] is not None -def test_main_hybrid_stage2_budget_counts_when_latency_cap_enforced(monkeypatch, tmp_path, capsys) -> None: +def test_main_hybrid_stage2_budget_counts_when_latency_cap_enforced( + monkeypatch, tmp_path, capsys +) -> None: runner = _load_runner_module() fixture = _load_hybrid_gate_fixture() @@ -403,7 +414,9 @@ def _write_stub_payload(*, report_path: Path, run_suffix: str) -> dict[str, Any] metrics = _summary_from_results(payload["results"]) payload["summary"] = 
metrics["summary"] payload["latency"] = metrics["latency"] - report_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + report_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) return payload def _fake_run_lancedb(**kwargs): diff --git a/tests/test_qmd_adapter.py b/tests/test_qmd_adapter.py index 498481d..b07ef57 100644 --- a/tests/test_qmd_adapter.py +++ b/tests/test_qmd_adapter.py @@ -42,7 +42,9 @@ def test_search_maps_session_id_from_path(monkeypatch) -> None: } def _ok(*args, **kwargs): - return subprocess.CompletedProcess(args=["qmd"], returncode=0, stdout=json.dumps(payload), stderr="") + return subprocess.CompletedProcess( + args=["qmd"], returncode=0, stdout=json.dumps(payload), stderr="" + ) monkeypatch.setattr(subprocess, "run", _ok) @@ -61,7 +63,9 @@ def test_search_maps_non_empty_fixture_payload(monkeypatch) -> None: fixture_stdout = fixture_path.read_text(encoding="utf-8") def _ok(*args, **kwargs): - return subprocess.CompletedProcess(args=["qmd"], returncode=0, stdout=fixture_stdout, stderr="") + return subprocess.CompletedProcess( + args=["qmd"], returncode=0, stdout=fixture_stdout, stderr="" + ) monkeypatch.setattr(subprocess, "run", _ok) @@ -99,7 +103,9 @@ def test_search_wires_query_command_limit_and_extra_args(monkeypatch) -> None: def _ok(cmd, *args, **kwargs): captured["cmd"] = cmd captured["timeout"] = kwargs.get("timeout") - return subprocess.CompletedProcess(args=cmd, returncode=0, stdout='{"results": []}', stderr="") + return subprocess.CompletedProcess( + args=cmd, returncode=0, stdout='{"results": []}', stderr="" + ) monkeypatch.setattr(subprocess, "run", _ok)