Commit 6386f2b

Add docs splitter and cross encoder based re-ranker; (#5)
Fix the formatting;
1 parent 6c6d36a commit 6386f2b

File tree

4 files changed (+81 -0 lines changed):
  app.py
  docker-compose.yml
  reranker.py
  splitter.py

app.py

Lines changed: 36 additions & 0 deletions
@@ -1,3 +1,4 @@
+import json
 from enum import Enum
 from typing import List, Optional, Union


@@ -10,6 +11,10 @@
     OpenAIEmbeddingModel,
     SentenceTransformerEmbeddingModel,
 )
+from reranker import get_scores
+
+# from langchain.schema import Document
+from splitter import get_split_documents_using_token_based

 load_dotenv()


@@ -35,6 +40,22 @@ class RequestSchemaForEmbeddings(BaseModel):
     base_url: Optional[str] = None


+class RequestSchemaForTextSplitter(BaseModel):
+    """Request Schema"""
+
+    model: str
+    documents: str
+    chunk_size: int
+    chunk_overlap: int
+
+
+class RequestSchemaForReRankers(BaseModel):
+    """Request Schema"""
+
+    query: str
+    documents: List[str]
+
+
 @app.get("/")
 async def home():
     """Returns a message"""

@@ -70,3 +91,18 @@ def generate(em_model, texts):
     elif type_model == EmbeddingModelType.OPENAI:
         embedding_model = OpenAIEmbeddingModel(model=name_model)
     return generate(em_model=embedding_model, texts=texts)
+
+
+@app.post("/split_docs_based_on_tokens")
+async def get_split_docs(item: RequestSchemaForTextSplitter):
+    """Splits the documents using the model tokenization method"""
+    docs = json.loads(item.documents)
+    return get_split_documents_using_token_based(
+        model_name=item.model, documents=docs, chunk_size=item.chunk_size, chunk_overlap=item.chunk_overlap
+    )
+
+
+@app.post("/docs_reranking_scores")
+async def get_reranked_docs(item: RequestSchemaForReRankers):
+    """Get reranked documents"""
+    return get_scores(item.query, item.documents)
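
For reference, a minimal client sketch against the two new endpoints; the base URL, the model name, and the sample documents here are assumptions, while the payload fields follow the RequestSchemaForTextSplitter and RequestSchemaForReRankers models above.

import json

import requests

BASE_URL = "http://localhost:8000"  # assumption: default port mapping from docker-compose.yml

# /split_docs_based_on_tokens expects `documents` as a JSON string encoding a list of
# {"page_content": ..., "metadata": ...} dicts (see splitter.py below).
docs = [{"page_content": "A long passage about embeddings ...", "metadata": {"source": "example"}}]
split_resp = requests.post(
    f"{BASE_URL}/split_docs_based_on_tokens",
    json={
        "model": "sentence-transformers/all-MiniLM-L6-v2",  # assumption: any SentenceTransformer model name
        "documents": json.dumps(docs),
        "chunk_size": 128,
        "chunk_overlap": 16,
    },
)
chunks = split_resp.json()

# /docs_reranking_scores takes a query plus a plain list of strings and returns
# one softmax-normalized relevance score per document.
rerank_resp = requests.post(
    f"{BASE_URL}/docs_reranking_scores",
    json={"query": "What is an embedding?", "documents": [c["page_content"] for c in chunks]},
)
print(rerank_resp.json())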

docker-compose.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ services:
     build: .
     volumes:
       - embedding_models:/opt/models
+      - .:/code
     command: bash -c 'uvicorn app:app --host=0.0.0.0 --port=8000'
     ports:
       - "8000:8000"

reranker.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from typing import List
+
+import torch.nn.functional as F
+from sentence_transformers import CrossEncoder
+from torch import Tensor
+
+
+def get_scores(query: str, documents: List[str], model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2"):
+    """Get the scores"""
+    model = CrossEncoder(model_name=model_name, max_length=512)
+    doc_tuple = [(query, doc) for doc in documents]
+    scores = model.predict(doc_tuple)
+    return F.softmax(Tensor(scores), dim=0).tolist()
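
The cross-encoder scores each (query, document) pair jointly, and the raw scores are softmax-normalized over the candidate set, so the returned list sums to roughly 1. A minimal direct-call sketch, with a made-up query and documents:

from reranker import get_scores

documents = [
    "Paris is the capital of France.",
    "The mitochondria is the powerhouse of the cell.",
    "France's capital city is Paris, on the Seine.",
]
# The default cross-encoder/ms-marco-MiniLM-L-2-v2 model is downloaded on first use.
scores = get_scores(query="What is the capital of France?", documents=documents)

# Higher score = more relevant; the scores sum to ~1.0 because of the softmax.
for doc, score in sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True):
    print(f"{score:.3f}  {doc}")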

splitter.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+from typing import List
+
+from langchain.schema import Document
+from langchain.text_splitter import SentenceTransformersTokenTextSplitter
+
+
+def langchain_document_to_dict(doc: Document):
+    """
+    Converts a langchain Document to a dictionary
+    """
+    return {"page_content": doc.page_content, "metadata": doc.metadata}
+
+
+def dict_to_langchain_document(doc: dict):
+    """
+    Converts a dictionary to a langchain Document
+    """
+    return Document(page_content=doc["page_content"], metadata=doc["metadata"])
+
+
+def get_split_documents_using_token_based(model_name: str, documents: List[dict], chunk_size: int, chunk_overlap: int):
+    """
+    Splits documents into multiple chunks using Sentence Transformer
+    token-based splitting.
+    """
+    splitter = SentenceTransformersTokenTextSplitter(
+        chunk_overlap=chunk_overlap, model_name=model_name, tokens_per_chunk=chunk_size
+    )
+    langchain_docs = [dict_to_langchain_document(d) for d in documents]
+    splitted_docs = splitter.split_documents(documents=langchain_docs)
+    return [langchain_document_to_dict(d) for d in splitted_docs]
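
A short sketch of driving the splitter directly; documents enter and leave as plain page_content/metadata dicts so they stay JSON-serializable for the API layer. The sample text and model name are assumptions, and chunk_size is forwarded as tokens_per_chunk:

from splitter import get_split_documents_using_token_based

docs = [
    {
        "page_content": "Sentence one about embeddings. Sentence two about rerankers. " * 50,
        "metadata": {"source": "example.txt"},
    }
]
chunks = get_split_documents_using_token_based(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumption: any SentenceTransformer model
    documents=docs,
    chunk_size=128,      # forwarded as tokens_per_chunk
    chunk_overlap=16,    # tokens shared between consecutive chunks
)
print(len(chunks), chunks[0]["metadata"])  # chunks keep the original metadata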
