Commit 223841b

refactor(generation): fix imports, function names, and update comments for prompt building (#19)
1 parent f13bda8 commit 223841b

File tree

9 files changed: +153 / -16 lines


app/core/config.py

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,8 @@
 from pydantic import Field
 from pydantic_settings import BaseSettings
 
+from app.core.constants import GeneratorBackend
+
 
 class Settings(BaseSettings):
     APP_NAME: str = "rag_mastery"
@@ -9,12 +11,15 @@ class Settings(BaseSettings):
     LOG_LEVEL: str = Field(default="INFO", alias="LOG_LEVEL")
     HOST: str = Field(default="0.0.0.0", alias="HOST")
     PORT: int = Field(default=8000, alias="PORT")
+    USE_REAL_GENERATOR: bool = False
+    GENERATOR_BACKEND: GeneratorBackend = GeneratorBackend.Mock
 
     model_config = {
         "env_file": ".env",
         "env_file_encoding": "utf-8",
         "extra": "ignore",
         "populate_by_name": True,  # allowing alias mapping
+        "use_enum_values": True,  # store enum values directly
     }

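With these fields in place, the backend can be switched from the project's .env file; a hypothetical example (the values here are illustrative, not part of this commit):

# .env (hypothetical values)
USE_REAL_GENERATOR=true
GENERATOR_BACKEND=ollama
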
app/core/constants.py

Lines changed: 12 additions & 0 deletions

@@ -1,6 +1,7 @@
 from enum import StrEnum
 
 
+# Ingestion status constants used across the application
 class IngestionStatus(StrEnum):
     Accepted = "accepted"
     Processing = "processing"
@@ -9,3 +10,14 @@ class IngestionStatus(StrEnum):
 
 
 DefaultTopK: int = 5
+
+
+class GeneratorBackend(StrEnum):
+    """
+    Supported generator backends.
+    Reason: to avoid hardcoding strings across the codebase.
+    """
+
+    Mock = "mock"
+    OPENAI = "openai"
+    Ollama = "ollama"
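
Since GeneratorBackend is a StrEnum, its members double as plain strings, which is what lets env-file values and the "use_enum_values" config round-trip cleanly; a quick sketch of the behavior being relied on:

from app.core.constants import GeneratorBackend

# StrEnum members are str subclasses, so they compare equal to raw strings
assert GeneratorBackend.Ollama == "ollama"
assert str(GeneratorBackend.OPENAI) == "openai"
# Lookup by value returns the canonical member
assert GeneratorBackend("mock") is GeneratorBackend.Mock
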
app/generation/adapters/ollama_adapter.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+import asyncio
+from app.core.interfaces import BaseGenerator
+from app.core.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class OllamaGenerator(BaseGenerator):
+    """
+    Async adapter for Ollama-style LLMs.
+    """
+
+    def __init__(self, model_name: str = "llama3"):
+        self.model_name = model_name
+
+    async def generate(self, prompt: str) -> str:
+        logger.debug(f"[OllamaGenerator] Would call model={self.model_name}")
+        await asyncio.sleep(0.5)  # simulate network delay
+        # Placeholder for actual Ollama API call
+        return f"[Simulated Ollama: {self.model_name} for prompt: {prompt[:60]}...]"
app/generation/adapters/openai_adapter.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+import asyncio
+
+from app.core.interfaces import BaseGenerator
+from app.core.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class OpenAIGenerator(BaseGenerator):
+    """
+    Async adapter for OpenAI-style LLMs.
+    """
+
+    def __init__(self, model_name: str = "gpt-4-turbo"):
+        self.model_name = model_name
+
+    async def generate(self, prompt: str) -> str:
+        logger.debug(f"[OpenAIGenerator] would call model={self.model_name}")
+        await asyncio.sleep(0.5)  # simulate network delay
+        return f"[Simulated OpenAI: {self.model_name} for prompt: {prompt[:60]}...]"

app/generation/api.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 from app.core.logging import get_logger
 from app.generation.deps import get_retriever
 from app.generation.models import GenerationRequest, GenerationResponse
-from app.generation.service import generate_anwer
+from app.generation.service import generate_answer
 
 logger = get_logger(__name__)
 
@@ -23,4 +23,4 @@ async def generate_endpoint(
     logger.info(
         "Received generation request", query=req.query, context_size=req.context_size
     )
-    return await generate_anwer(req, retriever)
+    return await generate_answer(req, retriever)

app/generation/deps.py

Lines changed: 20 additions & 3 deletions

@@ -1,8 +1,11 @@
+from app.core.constants import GeneratorBackend
 from app.core.interfaces import BaseRetriever
+from app.generation.adapters.ollama_adapter import OllamaGenerator
 from app.generation.mock_generator import MockGenerator
 from app.retrieval.models import RetrievalRequest
 from app.retrieval.service import retrieve_documents
 from app.core.repositories import global_vector_repo
+from app.core.config import settings
 
 
 class RetrievalAdapter(BaseRetriever):
@@ -35,6 +38,20 @@ def get_retriever() -> BaseRetriever:
     return _retriever
 
 
-async def get_generator():
-    # swap with real generator later
-    yield MockGenerator()
+async def get_generator(
+    use_real: bool = settings.USE_REAL_GENERATOR,
+    backend: GeneratorBackend = settings.GENERATOR_BACKEND,
+):
+    """
+    Returns a generator instance based on configuration.
+    """
+    # use_real takes precedence over backend: when False, always yield the mock
+    if not use_real:
+        yield MockGenerator()
+    else:
+        if backend == GeneratorBackend.Mock:
+            yield MockGenerator()
+        elif backend == GeneratorBackend.Ollama:
+            yield OllamaGenerator()
+        else:
+            yield OllamaGenerator()  # Default to OllamaGenerator for now
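
Because get_generator is an async generator, FastAPI can inject it with Depends; a hypothetical wiring sketch (the route below is illustrative; the real endpoint lives in app/generation/api.py):

from fastapi import APIRouter, Depends

from app.core.interfaces import BaseGenerator
from app.generation.deps import get_generator

router = APIRouter()


# Note: since use_real and backend have plain defaults, FastAPI will also
# expose them as optional query parameters on any route using this dependency.
@router.post("/demo-generate")
async def demo_generate(
    prompt: str,
    generator: BaseGenerator = Depends(get_generator),
) -> dict:
    return {"answer": await generator.generate(prompt)}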

app/generation/prompt_builder.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+def build_prompt(query: str, contexts: list[dict]) -> str:
+    """
+    Assembles a prompt text from retrieved contexts and the input query.
+
+    Expects contexts as list of dicts, e.g.:
+    [
+        {"doc_id": "1", "score": 0.8, "metadata": {"text": "FastAPI is async..."}},
+        ...
+    ]
+    """
+    if not contexts:
+        joined_contexts = "[No relevant context found.]"
+
+    else:
+        context_texts = []
+        for c in contexts:
+            meta = c.get("metadata", {})
+            text = meta.get("text")
+            if text:
+                context_texts.append(text.strip())
+            else:
+                context_texts.append(f"[Doc:{c.get('doc_id', 'unknown')}]")
+
+        joined_contexts = "\n\n".join(context_texts)
+    return (
+        f"### Contexts:\n{joined_contexts}\n\n"
+        f"### Question:\n{query.strip()}\n\n"
+        "### Answer:\n"
+    )
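
For reference, what build_prompt produces for one chunk with text and one without (values illustrative):

from app.generation.prompt_builder import build_prompt

contexts = [
    {"doc_id": "1", "score": 0.8, "metadata": {"text": "FastAPI is async..."}},
    {"doc_id": "2", "score": 0.7, "metadata": {"source": "source2"}},  # no "text"
]
print(build_prompt("What is FastAPI?", contexts))
# ### Contexts:
# FastAPI is async...
#
# [Doc:2]
#
# ### Question:
# What is FastAPI?
#
# ### Answer: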

app/generation/service.py

Lines changed: 9 additions & 11 deletions

@@ -1,15 +1,14 @@
-import asyncio
-
 from app.core.logging import get_logger
 from app.generation.models import GenerateAnswer, GenerationRequest, GenerationResponse
+from app.generation.prompt_builder import build_prompt
 
 from app.core.interfaces import BaseRetriever
 
 
 logger = get_logger(__name__)
 
 
-async def generate_anwer(
+async def generate_answer(
     req: GenerationRequest, retriever: BaseRetriever
 ) -> GenerationResponse:
     """
@@ -24,22 +23,21 @@ async def generate_anwer(
     GenerationResponse: The response containing the original query and generated answer.
     """
     # retrieve top-k relevant documents based on the query
+    logger.info(f"Generation started for query='{req.query}'")
+
     retrieved_chunks = await retriever.retrieve(req.query, req.context_size)
     """
     e.g: retrieved_chunks = [
-        {"doc_id": "1", "content": "Document content 1", "metadata": {"source": "source1"}},
-        {"doc_id": "2", "content": "Document content 2", "metadata": {"source": "source2"}},
+        {"doc_id": "1", "score": 0.8, "metadata": {"source": "source1"}},
+        {"doc_id": "2", "score": 0.7, "metadata": {"source": "source2"}},
         ...
     ]
+    Note: Does not include original text; only doc_id and metadata.
     """
+    logger.debug(f"Retrieved {len(retrieved_chunks)} chunks")
 
-    await asyncio.sleep(0.1)  # simulate llm latency
+    _ = build_prompt(req.query, retrieved_chunks)
 
-    # e.g context = [
-    #     {"doc_id": "1", "content": "Document content 1", "metadata": {"source": "source1"}},
-    #     {"doc_id": "2", "content": "Document content 2", "metadata": {"source": "source2"}},
-    #     ...
-    # ]
     synthesized = (
         " ".join(chunk["doc_id"] for chunk in retrieved_chunks) or "No context found."
     )
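
Note that the built prompt is discarded for now (_ = build_prompt(...)); once the service accepts a generator dependency, the wiring would presumably look something like this hypothetical sketch:

# Hypothetical future wiring, not part of this commit:
prompt = build_prompt(req.query, retrieved_chunks)
answer = await generator.generate(prompt)  # generator: a BaseGenerator from get_generator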

tests/generation/test_deps.py

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+from app.core.constants import GeneratorBackend
+from app.generation.adapters.ollama_adapter import OllamaGenerator
+from app.generation.deps import get_generator
+from app.generation.mock_generator import MockGenerator
+
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_get_generator_returns_mock_when_use_real_false_but_backend_not_provided():
+    """
+    When use_real is False and no backend is provided, should return MockGenerator.
+    """
+    gen = await anext(get_generator(use_real=False))
+    assert isinstance(gen, MockGenerator)
+
+
+@pytest.mark.asyncio
+async def test_get_generator_returns_mock_when_use_real_true_and_backend_mock():
+    """
+    When use_real is True and backend is Mock, should return MockGenerator.
+    """
+    gen = await anext(get_generator(use_real=True, backend=GeneratorBackend.Mock))
+    assert isinstance(gen, MockGenerator)
+
+
+@pytest.mark.asyncio
+async def test_get_generator_returns_ollama_when_use_real_true_and_backend_ollama():
+    """
+    When use_real is True and backend is Ollama, should return OllamaGenerator.
+    """
+    gen = await anext(get_generator(use_real=True, backend=GeneratorBackend.Ollama))
+    assert isinstance(gen, OllamaGenerator)
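
These tests drive the dependency directly with the anext() builtin (Python 3.10+), and the async test functions rely on the pytest-asyncio plugin for the asyncio marker.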
