
Commit 9368399

Feature/evaluation metrics (#27)
* feat(evaluation): add simple faithfulness overlap metric with tests
* feat(evaluation): add JSONL trace capture for RAG runs
* feat(evaluation): add evaluation summary endpoint from trace aggregates
1 parent b054c8e commit 9368399

File tree

9 files changed: +205 −0 lines changed


app/api/router.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 from app.retrieval.api import router as retrival_router
 from app.ingestion.api import router as ingestion_router
 from app.generation.api import router as generation_router
+from app.evaludation.api import router as evaluation_router


 router = APIRouter(prefix="/api/v1", tags=["api"])
@@ -22,3 +23,4 @@ async def ping() -> dict[str, Any]:
 router.include_router(retrival_router)
 router.include_router(ingestion_router)
 router.include_router(generation_router)
+router.include_router(evaluation_router)

app/evaludation/aggregate.py

Lines changed: 48 additions & 0 deletions

import json
from pathlib import Path


TRACE_FILE = Path(__file__).parent.parent.parent / "var/traces" / "rage_trace.jsonl"


def aggregate_traces() -> dict:
    """
    Aggregate traces from the JSONL trace file.

    Read the file line by line, parse each JSON record, and compute
    average metrics such as recall_k and faithfulness.

    :return: Aggregated trace metrics.
    :rtype: dict
    """
    total = 0
    recall_k_sum = 0.0
    recall_n = 0  # number of records with a recall_k value
    faithfulness_sum = 0.0
    faithfulness_n = 0  # number of records with a faithfulness value

    if not TRACE_FILE.exists():
        return {
            "runs": 0,
            "avg_recall_k": None,
            "avg_faithfulness": None,
        }

    with TRACE_FILE.open("r", encoding="utf-8") as f:
        for line in f:
            total += 1
            r = json.loads(line)
            if r.get("recall_k") is not None:
                recall_k_sum += r["recall_k"]
                recall_n += 1
            if r.get("faithfulness") is not None:
                faithfulness_sum += r["faithfulness"]
                faithfulness_n += 1

    return {
        "runs": total,
        "avg_recall_k": round(recall_k_sum / recall_n, 4) if recall_n > 0 else None,
        "avg_faithfulness": round(faithfulness_sum / faithfulness_n, 4)
        if faithfulness_n > 0
        else None,
    }
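For reference, the averages skip records where a metric is null, so runs with missing scores do not drag the mean down. A sketch with two hypothetical trace lines (values illustrative, not from this commit):

# Two hypothetical records in rage_trace.jsonl:
#   {"ts": 1718000000.0, "recall_k": 1.0, "faithfulness": 0.4}
#   {"ts": 1718000001.0, "recall_k": 0.5, "faithfulness": null}
#
# aggregate_traces() would then return:
#   {"runs": 2, "avg_recall_k": 0.75, "avg_faithfulness": 0.4}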

app/evaludation/api.py

Lines changed: 11 additions & 0 deletions

from fastapi import APIRouter

from app.evaludation.aggregate import aggregate_traces


router = APIRouter(prefix="/eval", tags=["evaluation"])


@router.get("/summary")
async def evaluation_summary() -> dict:
    return aggregate_traces()
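Because this router is included under the /api/v1 prefix in app/api/router.py, the endpoint resolves to /api/v1/eval/summary. A quick smoke check, assuming the app runs locally on port 8000 (host, port, and response values are illustrative):

# curl http://localhost:8000/api/v1/eval/summary
# {"runs": 2, "avg_recall_k": 0.75, "avg_faithfulness": 0.4}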

app/evaludation/faithfulness.py

Lines changed: 37 additions & 0 deletions

import re
from typing import Iterable


def _tokens(text: str) -> set[str]:
    """
    Tokenize the input text into a set of normalized (lowercased) words.
    """
    return set(re.findall(r"\b\w+\b", text.lower()))


def faithfulness_overlap(answer: str, contexts: Iterable[str]) -> float:
    """
    Calculate the faithfulness overlap score between the answer and the provided contexts.

    The score is the ratio of distinct answer words that also appear in the
    contexts to the total number of distinct words in the answer.

    Args:
        answer (str): The generated answer text.
        contexts (Iterable[str]): An iterable of context strings.

    Returns:
        float: The faithfulness overlap score (0.0 to 1.0).
    """
    # Normalize and tokenize the answer
    answer_tokens = _tokens(answer)
    # An empty or whitespace-only answer has no tokens, so score it 0.0
    if not answer_tokens:
        return 0.0

    # Normalize and tokenize all contexts into one token set
    context_tokens = set()
    for context in contexts:
        context_tokens.update(_tokens(context))

    # Intersection over answer tokens
    return round(len(answer_tokens & context_tokens) / len(answer_tokens), 4)
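Worked example mirroring the unit test added in this commit: the answer tokenizes to five distinct words, two of which appear in the context, so the score is 2/5.

answer = "RAG combines retrieval and generation"
contexts = ["retrieval augmented generation uses retrieval"]
# answer tokens:  {"rag", "combines", "retrieval", "and", "generation"}  (5 words)
# context tokens: {"retrieval", "augmented", "generation", "uses"}
# overlap: {"retrieval", "generation"} -> 2 / 5 = 0.4
print(faithfulness_overlap(answer, contexts))  # 0.4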

app/evaludation/recall.py

Lines changed: 30 additions & 0 deletions

from typing import Iterable, Set


def recall_at_k(retrieved_ids: Iterable[str], relevant_ids: Set[str], k: int) -> float:
    """
    Calculate Recall@K for retrieved document IDs against a set of relevant document IDs.

    Args:
        retrieved_ids (Iterable[str]): An iterable of retrieved document IDs, in rank order.
        relevant_ids (Set[str]): A set of relevant document IDs.
        k (int): The cutoff rank K.

    Returns:
        float: The Recall@K value.
    """
    if k <= 0:
        raise ValueError("k must be a positive integer")

    relevant_ids_set = set(relevant_ids)

    # With no relevant documents, recall is defined as 0.0 here
    if not relevant_ids_set:
        return 0.0

    # Truncate retrieved_ids to the top k results
    top_k_retrieved_ids_set = set(list(retrieved_ids)[:k])

    return round(
        len(relevant_ids_set & top_k_retrieved_ids_set) / len(relevant_ids_set), 4
    )
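Worked example matching the unit test added in this commit: one relevant document found within the top 3 gives 1/1 = 1.0, while a cutoff of k=1 misses it.

retrieved_ids = ["doc1", "doc2", "doc3"]
relevant_ids = {"doc2"}
print(recall_at_k(retrieved_ids, relevant_ids, k=3))  # 1.0  ("doc2" is in the top 3)
print(recall_at_k(retrieved_ids, relevant_ids, k=1))  # 0.0  (only "doc1" is considered)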

app/evaludation/trace_writer.py

Lines changed: 43 additions & 0 deletions

import json
import time
from pathlib import Path
from typing import Iterable


TRACE_DIR = Path(__file__).parent.parent.parent / "var/traces"
TRACE_DIR.mkdir(parents=True, exist_ok=True)
TRACE_FILE = TRACE_DIR / "rage_trace.jsonl"


def write_trace(
    query: str,
    retrieved_ids: Iterable[str],
    answer_text: str,
    recall_k: float | None = None,
    faithfulness: float | None = None,
) -> None:
    """
    Write a trace of each generation run to a JSONL file.

    :param query: The user query that triggered the generation.
    :type query: str
    :param retrieved_ids: IDs of the documents retrieved for the query.
    :type retrieved_ids: Iterable[str]
    :param answer_text: The generated answer text.
    :type answer_text: str
    :param recall_k: Optional Recall@K score for this run.
    :type recall_k: float | None
    :param faithfulness: Optional faithfulness overlap score for this run.
    :type faithfulness: float | None
    """
    record = {
        "ts": time.time(),
        "query": query,
        "retrieved_ids": list(retrieved_ids),
        "answer": answer_text,
        "recall_k": recall_k,
        "faithfulness": faithfulness,
    }

    with TRACE_FILE.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
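Each call appends one JSON object per line, which is what aggregate_traces reads back. A minimal usage sketch (query, IDs, and the ts value are illustrative):

write_trace(
    query="what is RAG?",
    retrieved_ids=["doc1", "doc2"],
    answer_text="RAG combines retrieval and generation",
)
# Appends a line like:
# {"ts": 1718000000.0, "query": "what is RAG?", "retrieved_ids": ["doc1", "doc2"],
#  "answer": "RAG combines retrieval and generation", "recall_k": null, "faithfulness": null}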

app/generation/service.py

Lines changed: 8 additions & 0 deletions

@@ -5,6 +5,7 @@
 from app.generation.prompt_builder import build_prompt

 from app.core.interfaces import BaseGenerator, BaseRetriever
+from app.evaludation.trace_writer import write_trace


 logger = get_logger(__name__)
@@ -52,6 +53,13 @@ async def generate_answer(

     logger.info(f"Generated answer for query='{req.query}'")

+    # Write trace for the generation
+    write_trace(
+        query=req.query,
+        retrieved_ids=[chunk["doc_id"] for chunk in retrieved_chunks],
+        answer_text=answer_text,
+    )
+
     return GenerationResponse(
         query=req.query,
         answer=GenerateAnswer(
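The trace here is written without metrics, so recall_k and faithfulness land as null in the JSONL file. A caller that has ground-truth document IDs could fill them in; a minimal sketch, where relevant_ids and a chunk["text"] field are assumptions, not part of this commit:

# Sketch only: relevant_ids and chunk["text"] are hypothetical
from app.evaludation.faithfulness import faithfulness_overlap
from app.evaludation.recall import recall_at_k

write_trace(
    query=req.query,
    retrieved_ids=[chunk["doc_id"] for chunk in retrieved_chunks],
    answer_text=answer_text,
    recall_k=recall_at_k([c["doc_id"] for c in retrieved_chunks], relevant_ids, k=5),
    faithfulness=faithfulness_overlap(answer_text, [c["text"] for c in retrieved_chunks]),
)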
tests/evaluation/test_faithfulness.py

Lines changed: 7 additions & 0 deletions

from app.evaludation.faithfulness import faithfulness_overlap


def test_faithfulness_basic():
    answer = "RAG combines retrieval and generation"
    contexts = ["retrieval augmented generation uses retrieval"]
    assert faithfulness_overlap(answer, contexts) == 0.4

tests/evaluation/test_recall.py

Lines changed: 19 additions & 0 deletions

from app.evaludation.recall import recall_at_k


def test_recall_at_k_basic():
    """
    Test basic functionality of recall_at_k.

    It should return the correct recall value for a simple case.
    """
    retrieved_ids = ["doc1", "doc2", "doc3"]
    relevant_ids = {"doc2"}
    k = 3
    # Expected: 1/1 = 1.0
    # Explanation: relevant_ids contains one document, "doc2",
    # and retrieved_ids includes "doc2" within the top 3.
    assert recall_at_k(retrieved_ids, relevant_ids, k) == 1.0
