Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 73 additions & 6 deletions backend/app/api/document_router.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from fastapi import (
APIRouter,
Body,
Expand All @@ -9,11 +10,13 @@
UploadFile,
)
from fastapi.responses import JSONResponse, RedirectResponse
from http import HTTPStatus
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders import UnstructuredPDFLoader
from langchain.text_splitter import CharacterTextSplitter
import os

from sqlalchemy.sql.operators import is_
from sqlmodel import Session, col, select
import tempfile
from typing import Annotated, Optional, Sequence, Union
Expand Down Expand Up @@ -109,14 +112,42 @@ async def upload_and_process_file(
async def list_documents(
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Sequence[DocumentWithChats]:
    """List the current user's documents, newest first, hiding soft-deleted data.

    Documents whose ``deleted_at`` is set are excluded by the query. For each
    remaining document, only chats whose ``deleted_at`` is ``None`` are
    included in the returned ``DocumentWithChats``.

    Args:
        session: Database session supplied by the ``get_session`` dependency.
        current_user: User supplied by the ``get_current_user`` dependency.

    Returns:
        One ``DocumentWithChats`` per non-deleted document owned by the user,
        ordered by ``created_at`` descending.
    """
    query = (
        select(Document)
        .where(Document.user_id == current_user.id, is_(Document.deleted_at, None))
        .order_by(col(Document.created_at).desc())
    )
    documents = list(session.exec(query))

    # Build response objects so soft-deleted chats are dropped from the
    # payload without touching the underlying ``doc.chats`` collection.
    return [
        DocumentWithChats(
            **doc.dict(),
            chats=[chat for chat in doc.chats if chat.deleted_at is None],
        )
        for doc in documents
    ]


# Delete uploaded document (set a deleted_at column to current timestamp)
@document_router.delete("/{document_id}")
async def delete_document(
    document_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Response:
    """Soft-delete one of the current user's documents.

    The row is kept; only its ``deleted_at`` column is set to the current
    UTC timestamp.

    Args:
        document_id: Primary key of the document to delete.
        session: Database session supplied by the ``get_session`` dependency.
        current_user: User supplied by the ``get_current_user`` dependency.

    Returns:
        An empty response with status 204 (No Content).

    Raises:
        HTTPException: 404 when no non-deleted document with this id belongs
            to the current user.
    """
    query = select(Document).where(
        Document.user_id == current_user.id,
        Document.id == document_id,
        # Treat an already soft-deleted document as missing, matching
        # delete_chat; otherwise a repeat call would overwrite the original
        # deletion timestamp.
        is_(Document.deleted_at, None),
    )
    document = session.exec(query).first()
    if document is None:
        raise HTTPException(status_code=404, detail="Document not found")
    document.deleted_at = datetime.datetime.utcnow()
    session.add(document)
    session.commit()
    return Response(status_code=HTTPStatus.NO_CONTENT)


@document_router.get("/{document_id}/new_chat")
Expand All @@ -126,6 +157,11 @@ async def create_new_chat(
current_user: User = Depends(get_current_user),
) -> int:
# TODO: Check that the document belongs to the user
document = session.get(Document, document_id)
if document.user_id != current_user.id:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, detail="Access not permitted"
)
new_chat = Chat(document_id=document_id)
session.add(new_chat)
session.commit()
Expand All @@ -134,6 +170,32 @@ async def create_new_chat(
return new_chat.id # type: ignore


@document_router.delete("/{document_id}/chat/{chat_id}")
async def delete_chat(
    document_id: int,
    chat_id: int,
    session: Session = Depends(get_session),
    current_user: User = Depends(get_current_user),
) -> Response:
    """Soft-delete a chat that belongs to one of the current user's documents.

    The chat row is kept; its ``deleted_at`` column is set to the current UTC
    timestamp.

    Args:
        document_id: Id of the document the chat must be attached to.
        chat_id: Id of the chat to delete.
        session: Database session supplied by the ``get_session`` dependency.
        current_user: User supplied by the ``get_current_user`` dependency.

    Returns:
        An empty response with status 204 (No Content).

    Raises:
        HTTPException: 404 when no matching, non-deleted chat is found for
            this user and document.
    """
    # All conditions must hold: right chat, right document, not yet deleted,
    # and the parent document is owned by the caller.
    match_conditions = (
        Chat.id == chat_id,
        Chat.document_id == document_id,
        is_(Chat.deleted_at, None),
        Document.user_id == current_user.id,
    )
    statement = select(Chat).join(Document).where(*match_conditions)

    target_chat = session.exec(statement).first()
    if target_chat is None:
        raise HTTPException(status_code=404, detail="Chat not found")

    target_chat.deleted_at = datetime.datetime.utcnow()
    session.add(target_chat)
    session.commit()
    return Response(status_code=HTTPStatus.NO_CONTENT)


@document_router.get("/{document_id}/chat/{chat_id}")
async def retrieve_chat(
document_id: int,
Expand All @@ -142,7 +204,9 @@ async def retrieve_chat(
current_user: User = Depends(get_current_user),
) -> Sequence[ChatMessage]:
query = select(ChatMessage).where(
ChatMessage.chat_id == chat_id, ChatMessage.user_id == current_user.id
ChatMessage.chat_id == chat_id,
ChatMessage.user_id == current_user.id,
is_(ChatMessage.deleted_at, None),
)
chat_messages = session.exec(query)
return list(chat_messages)
Expand All @@ -157,7 +221,10 @@ def get_k_similar_chunks(
) -> Sequence[str]:
query = (
select(VectorEmbedding.content)
.where(VectorEmbedding.document_id == document_id)
.where(
VectorEmbedding.document_id == document_id,
# is_(VectorEmbedding.deleted_at, None),
)
.order_by(VectorEmbedding.embedding.l2_distance(query_embedding)) # type: ignore
.limit(k)
)
Expand Down Expand Up @@ -201,7 +268,7 @@ async def get_ai_response(
chat = ChatOpenAI(model=model, temperature=temperature) # type: ignore
previous_messages_query = (
select(ChatMessage)
.where(ChatMessage.chat_id == chat_id)
.where(ChatMessage.chat_id == chat_id, is_(ChatMessage.deleted_at, None))
.order_by(col(ChatMessage.created_at).desc())
.limit(2)
)
Expand Down
3 changes: 3 additions & 0 deletions backend/app/models/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class VectorEmbedding(SQLModel, table=True):
content: str = Field(...)

created_at: datetime = Field(default_factory=datetime.now)
# deleted_at: datetime = Field(default=None)


class ChatOriginator(str, Enum):
Expand All @@ -42,6 +43,8 @@ class Chat(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
document_id: int = Field(..., foreign_key="document.id")
document: Document = Relationship(back_populates="chats")
created_at: datetime = Field(default_factory=datetime.now)
deleted_at: Optional[datetime] = Field(default=None)


class ChatMessage(SQLModel, table=True):
Expand Down
31 changes: 31 additions & 0 deletions backend/app/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Setup per https://sqlmodel.tiangolo.com/tutorial/fastapi/tests/

import datetime
from fastapi.testclient import TestClient
import pytest
import pytest_asyncio
Expand Down Expand Up @@ -71,6 +72,21 @@ async def documents(session: Session, user: User) -> List[Document]:
return documents


@pytest_asyncio.fixture
async def semi_deleted_documents(session: Session, user: User) -> List[Document]:
    """Persist two documents for ``user``: one live and one soft-deleted.

    Returns the documents in insertion order: "Document 1" (no ``deleted_at``)
    followed by "Document 2" (``deleted_at`` set to the current time).
    """
    live_document = Document(title="Document 1", user_id=user.id)
    removed_document = Document(
        title="Document 2",
        user_id=user.id,
        deleted_at=datetime.datetime.now(),
    )
    seeded = [live_document, removed_document]

    session.add_all(seeded)
    session.commit()

    # Reload each row from the database after the commit.
    for row in seeded:
        session.refresh(row)

    return seeded


@pytest_asyncio.fixture
async def chats(session: Session, documents: List[Document]) -> List[Chat]:
chats = [
Expand All @@ -93,3 +109,18 @@ async def chats(session: Session, documents: List[Document]) -> List[Chat]:
session.commit()

return chats


@pytest_asyncio.fixture
async def semi_deleted_chats(session: Session, documents: List[Document]) -> List[Chat]:
    """Persist two chats on the first document: one live and one soft-deleted.

    Returns the chats in insertion order: "Chat 1" (no ``deleted_at``)
    followed by "Chat 2" (``deleted_at`` set to the current time).
    """
    parent_id = documents[0].id
    live_chat = Chat(document_id=parent_id, title="Chat 1")
    removed_chat = Chat(
        document_id=parent_id,
        title="Chat 2",
        deleted_at=datetime.datetime.now(),
    )
    seeded = [live_chat, removed_chat]

    session.add_all(seeded)
    session.commit()

    # Reload each row from the database after the commit.
    for row in seeded:
        session.refresh(row)

    return seeded
Loading