Commit 81ff565
feat(generator): add generator adapters with OpenAI, Ollama, and mock implementations
1 parent 1bdb856 commit 81ff565

11 files changed: +153 additions, -21 deletions

app/core/config.py

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,8 @@
 from pydantic import Field
 from pydantic_settings import BaseSettings
 
+from app.core.constants import GeneratorBackend
+
 
 class Settings(BaseSettings):
     APP_NAME: str = "rag_mastery"
@@ -9,12 +11,15 @@ class Settings(BaseSettings):
     LOG_LEVEL: str = Field(default="INFO", alias="LOG_LEVEL")
     HOST: str = Field(default="0.0.0.0", alias="HOST")
     PORT: int = Field(default=8000, alias="PORT")
+    USE_REAL_GENERATOR: bool = False
+    GENERATOR_BACKEND: GeneratorBackend = GeneratorBackend.Mock
 
     model_config = {
         "env_file": ".env",
         "env_file_encoding": "utf-8",
         "extra": "ignore",
         "populate_by_name": True,  # allowing alias mapping
+        "use_enum_values": True,  # store enum values directly
     }
 
 
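Note: with these two settings, the active backend can be switched per environment. A minimal sketch of how the values resolve, assuming a hypothetical .env alongside the app:

    # .env (hypothetical contents):
    #   USE_REAL_GENERATOR=true
    #   GENERATOR_BACKEND=ollama
    from app.core.config import Settings

    settings = Settings()
    # Because use_enum_values=True, the field holds the plain string value,
    # which still compares equal to the StrEnum member.
    assert settings.USE_REAL_GENERATOR is True
    assert settings.GENERATOR_BACKEND == "ollama"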

app/core/constants.py

Lines changed: 12 additions & 0 deletions

@@ -1,6 +1,7 @@
 from enum import StrEnum
 
 
+# Ingestion status constants used across the application
 class IngestionStatus(StrEnum):
     Accepted = "accepted"
     Processing = "processing"
@@ -9,3 +10,14 @@ class IngestionStatus(StrEnum):
 
 
 DefaultTopK: int = 5
+
+
+class GeneratorBackend(StrEnum):
+    """
+    Supported generator backends.
+    Defined once to avoid hardcoding backend strings across the codebase.
+    """
+
+    Mock = "mock"
+    OpenAI = "openai"
+    Ollama = "ollama"
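Note: GeneratorBackend is a StrEnum, so its members behave as plain strings; this is what makes the use_enum_values setting in config.py safe. A quick illustration:

    from app.core.constants import GeneratorBackend

    backend = GeneratorBackend("ollama")  # parse from a raw string
    assert backend is GeneratorBackend.Ollama
    assert backend == "ollama"            # members compare equal to str
    assert f"backend={backend}" == "backend=ollama"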

app/core/metrics.py

Lines changed: 15 additions & 1 deletion

@@ -1,7 +1,7 @@
 import os
 
 from fastapi import Response
-from prometheus_client import CONTENT_TYPE_LATEST, Counter, generate_latest
+from prometheus_client import CONTENT_TYPE_LATEST, Counter, generate_latest, Histogram
 
 # Identify which service (ingestion, retrieval, generation, etc.)
 APP_NAME = os.getenv("APP_NAME", "rag_architect")
@@ -19,6 +19,20 @@
     ["app_name", "method", "endpoint", "http_status"],  # metric dimensions
 )
 
+# Prompt build duration histogram
+APP_PROMPT_BUILD_SECONDS = Histogram(
+    "app_prompt_build_seconds",
+    "Time taken to build prompts",
+    ["app_name"],
+)
+
+# LLM generation latency histogram
+APP_GENERATION_LATENCY_SECONDS = Histogram(
+    "app_generation_latency_seconds",
+    "Time taken for LLM answer generation",
+    ["app_name"],
+)
+
 
 def record_request(method: str, endpoint: str, http_status: str):
     """
app/generation/adapters/ollama_adapter.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+import asyncio
+from app.core.interfaces import BaseGenerator
+from app.core.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class OllamaGenerator(BaseGenerator):
+    """
+    Async adapter for Ollama-style LLMs.
+    """
+
+    def __init__(self, model_name: str = "llama3"):
+        self.model_name = model_name
+
+    async def generate(self, prompt: str) -> str:
+        logger.debug(f"[OllamaGenerator] Would call model={self.model_name}")
+        await asyncio.sleep(0.5)  # simulate network delay
+        # Placeholder for actual Ollama API call
+        return f"[Simulated Ollama: {self.model_name} for prompt: {prompt[:60]}...]"
app/generation/adapters/openai_adapter.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+import asyncio
+
+from app.core.interfaces import BaseGenerator
+from app.core.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class OpenAIGenerator(BaseGenerator):
+    """
+    Async adapter for OpenAI-style LLMs.
+    """
+
+    def __init__(self, model_name: str = "gpt-4-turbo"):
+        self.model_name = model_name
+
+    async def generate(self, prompt: str) -> str:
+        logger.debug(f"[OpenAIGenerator] would call model={self.model_name}")
+        await asyncio.sleep(0.5)  # simulate network delay
+        return f"[Simulated OpenAI: {self.model_name} for prompt: {prompt[:60]}...]"

app/generation/api.py

Lines changed: 6 additions & 4 deletions

@@ -1,8 +1,8 @@
 from fastapi import APIRouter, status, Depends
 
-from app.core.interfaces import BaseRetriever
+from app.core.interfaces import BaseGenerator, BaseRetriever
 from app.core.logging import get_logger
-from app.generation.deps import get_retriever
+from app.generation.deps import get_generator, get_retriever
 from app.generation.models import GenerationRequest, GenerationResponse
 from app.generation.service import generate_answer
 
@@ -18,9 +18,11 @@
     summary="Generate answer based on query and retrieved contexts",
 )
 async def generate_endpoint(
-    req: GenerationRequest, retriever: BaseRetriever = Depends(get_retriever)
+    req: GenerationRequest,
+    retriever: BaseRetriever = Depends(get_retriever),
+    generator: BaseGenerator = Depends(get_generator),
 ):
     logger.info(
         "Received generation request", query=req.query, context_size=req.context_size
     )
-    return await generate_answer(req, retriever)
+    return await generate_answer(req, retriever, generator)
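Note: with both dependencies injected, the endpoint can be exercised over HTTP. A hedged example request; the route path and port are assumptions, since the router prefix is not shown in this diff:

    import httpx

    resp = httpx.post(
        "http://localhost:8000/generate",  # hypothetical path
        json={"query": "what is RAG?", "context_size": 3},
    )
    print(resp.json()["answer"]["text"])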

app/generation/deps.py

Lines changed: 20 additions & 3 deletions

@@ -1,8 +1,12 @@
+from app.core.constants import GeneratorBackend
 from app.core.interfaces import BaseRetriever
+from app.generation.adapters.ollama_adapter import OllamaGenerator
+from app.generation.adapters.openai_adapter import OpenAIGenerator
 from app.generation.mock_generator import MockGenerator
 from app.retrieval.models import RetrievalRequest
 from app.retrieval.service import retrieve_documents
 from app.core.repositories import global_vector_repo
+from app.core.config import settings
 
 
 class RetrievalAdapter(BaseRetriever):
@@ -35,6 +39,19 @@ def get_retriever() -> BaseRetriever:
     return _retriever
 
 
-async def get_generator():
-    # swap with real generator later
-    yield MockGenerator()
+async def get_generator(
+    use_real: bool = settings.USE_REAL_GENERATOR,
+    backend: GeneratorBackend = settings.GENERATOR_BACKEND,
+):
+    """
+    Yield a generator implementation selected by configuration.
+    """
+    # use_real takes precedence: when False, always serve the mock
+    if not use_real:
+        yield MockGenerator()
+    elif backend == GeneratorBackend.OpenAI:
+        yield OpenAIGenerator()
+    elif backend == GeneratorBackend.Ollama:
+        yield OllamaGenerator()
+    else:
+        yield MockGenerator()  # covers GeneratorBackend.Mock and unknown values
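Note: because get_generator is an async generator dependency, FastAPI resolves it per request, and tests can bypass settings entirely via dependency_overrides. A sketch; the app import path is hypothetical:

    from app.main import app  # hypothetical application module
    from app.generation.deps import get_generator
    from app.generation.mock_generator import MockGenerator

    async def override_generator():
        yield MockGenerator()

    app.dependency_overrides[get_generator] = override_generator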

app/generation/models.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ class GenerateAnswer(BaseModel):
     """
 
     text: str
-    used_contexts: list[dict[str, Any]]
+    used_context: list[dict[str, Any]]
 
 
 class GenerationResponse(BaseModel):

app/generation/service.py

Lines changed: 17 additions & 11 deletions

@@ -1,15 +1,19 @@
+import time
 from app.core.logging import get_logger
+from app.core.metrics import APP_NAME, APP_GENERATION_LATENCY_SECONDS, APP_PROMPT_BUILD_SECONDS
 from app.generation.models import GenerateAnswer, GenerationRequest, GenerationResponse
 from app.generation.prompt_builder import build_prompt
 
-from app.core.interfaces import BaseRetriever
+from app.core.interfaces import BaseGenerator, BaseRetriever
 
 
 logger = get_logger(__name__)
 
 
 async def generate_answer(
-    req: GenerationRequest, retriever: BaseRetriever
+    req: GenerationRequest,
+    retriever: BaseRetriever,
+    generator: BaseGenerator,
 ) -> GenerationResponse:
     """
     generate handles the text generation process by retrieving relevant documents
@@ -36,20 +40,22 @@ async def generate_answer(
     """
     logger.debug(f"Retrieved {len(retrieved_chunks)} chunks")
 
-    _ = build_prompt(req.query, retrieved_chunks)
+    # Prompt build metrics
+    t0 = time.monotonic()
+    prompt = build_prompt(req.query, retrieved_chunks)
+    APP_PROMPT_BUILD_SECONDS.labels(app_name=APP_NAME).observe(time.monotonic() - t0)
 
-    synthesized = (
-        " ".join(chunk["doc_id"] for chunk in retrieved_chunks) or "No context found."
-    )
+    # Generation latency metrics
+    t1 = time.monotonic()
+    answer_text = await generator.generate(prompt=prompt)
+    APP_GENERATION_LATENCY_SECONDS.labels(app_name=APP_NAME).observe(time.monotonic() - t1)
 
-    logger.info(
-        f"Generated answer for query='{req.query}' using {len(retrieved_chunks)}"
-    )
+    logger.info(f"Generated answer for query='{req.query}'")
 
     return GenerationResponse(
         query=req.query,
         answer=GenerateAnswer(
-            text=f"Mock answer: {synthesized}",
-            used_contexts=retrieved_chunks,
+            text=answer_text,
+            used_context=retrieved_chunks,
         ),
     )
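Note: the service can also be driven directly in an async context, which is handy for smoke tests. A sketch under stated assumptions: the retriever stub and its method signature are hypothetical, since BaseRetriever's interface is not shown in this diff:

    from app.core.interfaces import BaseRetriever
    from app.generation.mock_generator import MockGenerator
    from app.generation.models import GenerationRequest
    from app.generation.service import generate_answer

    class EmptyRetriever(BaseRetriever):
        # Hypothetical stub: returns no chunks, so the prompt carries no context.
        async def retrieve(self, query: str, top_k: int) -> list[dict]:
            return []

    async def smoke() -> None:
        req = GenerationRequest(query="what is RAG?", context_size=3)
        resp = await generate_answer(req, retriever=EmptyRetriever(), generator=MockGenerator())
        print(resp.answer.text)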

tests/generation/test_deps.py

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+from app.core.constants import GeneratorBackend
+from app.generation.adapters.ollama_adapter import OllamaGenerator
+from app.generation.deps import get_generator
+from app.generation.mock_generator import MockGenerator
+
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_get_generator_returns_mock_when_use_real_false_but_backend_not_provided():
+    """
+    When use_real is False and no backend is provided, should return MockGenerator.
+    """
+    gen = await anext(get_generator(use_real=False))
+    assert isinstance(gen, MockGenerator)
+
+
+@pytest.mark.asyncio
+async def test_get_generator_returns_mock_when_use_real_true_and_backend_mock():
+    """
+    When use_real is True and backend is Mock, should return MockGenerator.
+    """
+    gen = await anext(get_generator(use_real=True, backend=GeneratorBackend.Mock))
+    assert isinstance(gen, MockGenerator)
+
+
+@pytest.mark.asyncio
+async def test_get_generator_returns_ollama_when_use_real_true_and_backend_ollama():
+    """
+    When use_real is True and backend is Ollama, should return OllamaGenerator.
+    """
+    gen = await anext(get_generator(use_real=True, backend=GeneratorBackend.Ollama))
+    assert isinstance(gen, OllamaGenerator)
