From 57619c86cc09a37e5543b8bd8365cbb5da7c4b71 Mon Sep 17 00:00:00 2001 From: Simon Podhajsky Date: Mon, 7 Aug 2023 16:15:39 +0200 Subject: [PATCH 1/2] Prepare backend deletion document/chat logic --- backend/app/api/document_router.py | 79 +++++++++- backend/app/models/document.py | 3 + backend/app/test/conftest.py | 31 ++++ backend/app/test/test_document_routes.py | 188 ++++++++++++++++++++--- 4 files changed, 276 insertions(+), 25 deletions(-) diff --git a/backend/app/api/document_router.py b/backend/app/api/document_router.py index 3b58c16..c9138f8 100644 --- a/backend/app/api/document_router.py +++ b/backend/app/api/document_router.py @@ -1,3 +1,4 @@ +import datetime from fastapi import ( APIRouter, Body, @@ -9,11 +10,13 @@ UploadFile, ) from fastapi.responses import JSONResponse, RedirectResponse +from http import HTTPStatus from langchain.embeddings.openai import OpenAIEmbeddings from langchain.document_loaders import UnstructuredPDFLoader from langchain.text_splitter import CharacterTextSplitter import os +from sqlalchemy.sql.operators import is_ from sqlmodel import Session, col, select import tempfile from typing import Annotated, Optional, Sequence, Union @@ -109,14 +112,42 @@ async def upload_and_process_file( async def list_documents( session: Session = Depends(get_session), current_user: User = Depends(get_current_user), -) -> Sequence[Document]: +) -> Sequence[DocumentWithChats]: query = ( select(Document) - .where(Document.user_id == current_user.id) + .where(Document.user_id == current_user.id, is_(Document.deleted_at, None)) .order_by(col(Document.created_at).desc()) ) documents = list(session.exec(query)) - return documents + + # Create a new list with filtered chats + docs_with_filtered_chats = [ + DocumentWithChats( + **doc.dict(), chats=[chat for chat in doc.chats if chat.deleted_at is None] + ) + for doc in documents + ] + + return docs_with_filtered_chats + + +# Delete uploaded document (set a deleted_at column to current timestamp) +@document_router.delete("/{document_id}") +async def delete_document( + document_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(get_current_user), +) -> Response: + query = select(Document).where( + Document.user_id == current_user.id, Document.id == document_id + ) + document = session.exec(query).first() + if document is None: + raise HTTPException(status_code=404, detail="Document not found") + document.deleted_at = datetime.datetime.utcnow() + session.add(document) + session.commit() + return Response(status_code=HTTPStatus.NO_CONTENT) @document_router.get("/{document_id}/new_chat") @@ -126,6 +157,11 @@ async def create_new_chat( current_user: User = Depends(get_current_user), ) -> int: # TODO: Check that the document belongs to the user + document = session.get(Document, document_id) + if document.user_id != current_user.id: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, detail="Access not permitted" + ) new_chat = Chat(document_id=document_id) session.add(new_chat) session.commit() @@ -134,6 +170,32 @@ async def create_new_chat( return new_chat.id # type: ignore +@document_router.delete("/{document_id}/chat/{chat_id}") +async def delete_chat( + document_id: int, + chat_id: int, + session: Session = Depends(get_session), + current_user: User = Depends(get_current_user), +) -> Response: + query = ( + select(Chat) + .join(Document) + .where( + Chat.id == chat_id, + Chat.document_id == document_id, + is_(Chat.deleted_at, None), + Document.user_id == current_user.id, + ) + ) + chat = session.exec(query).first() + if chat is None: + raise HTTPException(status_code=404, detail="Chat not found") + chat.deleted_at = datetime.datetime.utcnow() + session.add(chat) + session.commit() + return Response(status_code=HTTPStatus.NO_CONTENT) + + @document_router.get("/{document_id}/chat/{chat_id}") async def retrieve_chat( document_id: int, @@ -142,7 +204,9 @@ async def retrieve_chat( current_user: User = Depends(get_current_user), ) -> Sequence[ChatMessage]: query = select(ChatMessage).where( - ChatMessage.chat_id == chat_id, ChatMessage.user_id == current_user.id + ChatMessage.chat_id == chat_id, + ChatMessage.user_id == current_user.id, + is_(ChatMessage.deleted_at, None), ) chat_messages = session.exec(query) return list(chat_messages) @@ -157,7 +221,10 @@ def get_k_similar_chunks( ) -> Sequence[str]: query = ( select(VectorEmbedding.content) - .where(VectorEmbedding.document_id == document_id) + .where( + VectorEmbedding.document_id == document_id, + # is_(VectorEmbedding.deleted_at, None), + ) .order_by(VectorEmbedding.embedding.l2_distance(query_embedding)) # type: ignore .limit(k) ) @@ -201,7 +268,7 @@ async def get_ai_response( chat = ChatOpenAI(model=model, temperature=temperature) # type: ignore previous_messages_query = ( select(ChatMessage) - .where(ChatMessage.chat_id == chat_id) + .where(ChatMessage.chat_id == chat_id, is_(ChatMessage.deleted_at, None)) .order_by(col(ChatMessage.created_at).desc()) .limit(2) ) diff --git a/backend/app/models/document.py b/backend/app/models/document.py index 11b69a4..7de0dc1 100644 --- a/backend/app/models/document.py +++ b/backend/app/models/document.py @@ -31,6 +31,7 @@ class VectorEmbedding(SQLModel, table=True): content: str = Field(...) created_at: datetime = Field(default_factory=datetime.now) + # deleted_at: datetime = Field(default=None) class ChatOriginator(str, Enum): @@ -42,6 +43,8 @@ class Chat(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) document_id: int = Field(..., foreign_key="document.id") document: Document = Relationship(back_populates="chats") + created_at: datetime = Field(default_factory=datetime.now) + deleted_at: Optional[datetime] = Field(default=None) class ChatMessage(SQLModel, table=True): diff --git a/backend/app/test/conftest.py b/backend/app/test/conftest.py index af97ebe..b3eb68f 100644 --- a/backend/app/test/conftest.py +++ b/backend/app/test/conftest.py @@ -1,5 +1,6 @@ # Setup per https://sqlmodel.tiangolo.com/tutorial/fastapi/tests/ +import datetime from fastapi.testclient import TestClient import pytest import pytest_asyncio @@ -71,6 +72,21 @@ async def documents(session: Session, user: User) -> List[Document]: return documents +@pytest_asyncio.fixture +async def semi_deleted_documents(session: Session, user: User) -> List[Document]: + documents = [ + Document(title="Document 1", user_id=user.id), + Document(title="Document 2", user_id=user.id, deleted_at=datetime.datetime.now()), + ] + + session.add_all(documents) + session.commit() + for document in documents: + session.refresh(document) + + return documents + + @pytest_asyncio.fixture async def chats(session: Session, documents: List[Document]) -> List[Chat]: chats = [ @@ -93,3 +109,18 @@ async def chats(session: Session, documents: List[Document]) -> List[Chat]: session.commit() return chats + + +@pytest_asyncio.fixture +async def semi_deleted_chats(session: Session, documents: List[Document]) -> List[Chat]: + chats = [ + Chat(document_id=documents[0].id, title="Chat 1"), + Chat(document_id=documents[0].id, title="Chat 2", deleted_at=datetime.datetime.now()) + ] + + session.add_all(chats) + session.commit() + for chat in chats: + session.refresh(chat) + + return chats \ No newline at end of file diff --git a/backend/app/test/test_document_routes.py b/backend/app/test/test_document_routes.py index a0bedac..984b191 100644 --- a/backend/app/test/test_document_routes.py +++ b/backend/app/test/test_document_routes.py @@ -1,3 +1,4 @@ +import datetime from http import HTTPStatus from io import BytesIO from typing import List @@ -109,7 +110,7 @@ async def test_list_messages_in_chat_unauthorized( # assert response.status_code == HTTPStatus.OK # # FIXME: Should return document instead # assert isinstance(resp_json, int) - + # document = session.get(Document, resp_json) # assert document.title == "sample.pdf" # assert document.user_id == 1 @@ -128,12 +129,14 @@ async def test_upload_file_unauthorized(client: TestClient, session: Session): new_docs_count = len(session.exec(select(Document)).all()) assert new_docs_count == docs_count assert response.status_code == HTTPStatus.UNAUTHORIZED - - + + @pytest.mark.asyncio -async def test_create_new_chat(client: TestClient, session: Session, documents: List[Document]): +async def test_create_new_chat( + client: TestClient, session: Session, documents: List[Document] +): document = documents[0] # choose the first document - + # override get_current_user dependency app.dependency_overrides[get_current_user] = mock_get_current_user response = client.get(f"/documents/{document.id}/new_chat") @@ -153,7 +156,9 @@ async def test_create_new_chat(client: TestClient, session: Session, documents: @pytest.mark.asyncio -async def test_create_new_chat_unauthenticated(client: TestClient, session: Session, documents: List[Document]): +async def test_create_new_chat_unauthenticated( + client: TestClient, session: Session, documents: List[Document] +): document = documents[0] # choose the first document response = client.get(f"/documents/{document.id}/new_chat") assert response.status_code == HTTPStatus.UNAUTHORIZED @@ -162,36 +167,181 @@ async def test_create_new_chat_unauthenticated(client: TestClient, session: Sess async def mock_embed_message(message): return [0.9] * 1536 + # Define a mock for get_ai_response function async def mock_get_ai_response(chat_id, session, gpt_prompt, message): from langchain.schema import AIMessage + return AIMessage(content="Mock AI response") + # Define a mock for the vector store retrieval function, since it can't be # done with SQLite def mock_get_similar_chunks(query_embedding, document_id, session, k=3): - return [ - f"Mock similar chunk {n}" - for n in range(0, k) - ] + return [f"Mock similar chunk {n}" for n in range(0, k)] + @pytest.mark.asyncio -@patch('app.api.document_router.embed_message', new=mock_embed_message) -@patch('app.api.document_router.get_ai_response', new=mock_get_ai_response) -@patch('app.api.document_router.get_k_similar_chunks', new=mock_get_similar_chunks) -@patch.dict(os.environ, {"OPENAI_API_KEY": "mock_key"}) # Check that key isn't actually being used -async def test_send_message(client: TestClient, session: Session, chats: List[Chat], user: User): +@patch("app.api.document_router.embed_message", new=mock_embed_message) +@patch("app.api.document_router.get_ai_response", new=mock_get_ai_response) +@patch("app.api.document_router.get_k_similar_chunks", new=mock_get_similar_chunks) +@patch.dict( + os.environ, {"OPENAI_API_KEY": "mock_key"} +) # Check that key isn't actually being used +async def test_send_message( + client: TestClient, session: Session, chats: List[Chat], user: User +): app.dependency_overrides[get_current_user] = mock_get_current_user chat = chats[0] # choose the first chat document_id = chat.document_id message = "Test message content" - response = client.post(f"/documents/{document_id}/chat/{chat.id}/message", json={"message": message}) + response = client.post( + f"/documents/{document_id}/chat/{chat.id}/message", json={"message": message} + ) app.dependency_overrides.clear() assert response.status_code == HTTPStatus.OK chat_message = response.json() - assert chat_message['content'] == "Mock AI response" - assert chat_message['chat_id'] == chat.id - assert chat_message['user_id'] == user.id \ No newline at end of file + assert chat_message["content"] == "Mock AI response" + assert chat_message["chat_id"] == chat.id + assert chat_message["user_id"] == user.id + + +# Test deletions of documents, chats, and chat messages +@pytest.mark.asyncio +async def test_delete_document( + client: TestClient, session: Session, documents: List[Document], chats: List[Chat] +): + document = documents[0] + assert document.deleted_at is None + + app.dependency_overrides[get_current_user] = mock_get_current_user + response = client.delete(f"/documents/{document.id}") + app.dependency_overrides.clear() + + assert response.status_code == HTTPStatus.NO_CONTENT + + updated_document = session.get(Document, document.id) + assert updated_document.deleted_at is not None + + # document_chats = session.exec( + # select(Chat).where(Chat.document_id == document.id) + # ).all() + # assert len(document_chats) > 0 + # for chat in document_chats: + # assert chat.deleted_at is not None + # chat_messages = session.exec( + # select(ChatMessage).where(ChatMessage.chat_id == chat.id) + # ).all() + # assert len(chat_messages) > 0 + # for chat_message in chat_messages: + # assert chat_message.deleted_at is not None + + +@pytest.mark.asyncio +async def test_cannot_delete_other_user_document( + client: TestClient, session: Session, documents: List[Document] +): + document = documents[0] + assert document.deleted_at is None + + app.dependency_overrides[get_current_user] = mock_get_another_user + response = client.delete(f"/documents/{document.id}") + app.dependency_overrides.clear() + + assert response.status_code == HTTPStatus.NOT_FOUND + + updated_document = session.get(Document, document.id) + assert updated_document.deleted_at is None + + # TODO: Test that all inheriting chats and chat messages are not deleted either? + + +@pytest.mark.asyncio +async def test_delete_chat(client: TestClient, session: Session, chats: List[Chat]): + chat = chats[0] + assert chat.deleted_at is None + + # TODO: Retrieve chat messages and test that they're not deleted either + app.dependency_overrides[get_current_user] = mock_get_current_user + response = client.delete(f"/documents/{chat.document_id}/chat/{chat.id}") + app.dependency_overrides.clear() + + assert response.status_code == HTTPStatus.NO_CONTENT + + updated_chat = session.get(Chat, chat.id) + assert updated_chat.deleted_at is not None + + # # TODO: All chat messages should also be marked deleted + # chat_messages = session.exec(select(ChatMessage).where(ChatMessage.chat_id == chat.id)).all() + # assert len(chat_messages) > 0 + # for chat_message in chat_messages: + # assert chat_message.deleted_at is not None + + +@pytest.mark.asyncio +async def test_cannot_delete_other_user_chat( + client: TestClient, session: Session, chats: List[Chat] +): + chat = chats[0] + assert chat.deleted_at is None + + app.dependency_overrides[get_current_user] = mock_get_another_user + response = client.delete(f"/documents/{chat.document_id}/chat/{chat.id}") + app.dependency_overrides.clear() + + assert response.status_code == HTTPStatus.NOT_FOUND + + updated_chat = session.get(Chat, chat.id) + assert updated_chat.deleted_at is None + + +# TODO: Test document and chat retrieval to see if they're including deleted documents and chats/chat messages +@pytest.mark.asyncio +async def test_list_documents_omits_deleted_docs( + client: TestClient, session: Session, semi_deleted_documents: List[Document] +): + app.dependency_overrides[get_current_user] = mock_get_current_user + response = client.get("/documents") + app.dependency_overrides.clear() + + assert response.status_code == HTTPStatus.OK + + document_list = response.json() + assert len(document_list) == 1 + + +@pytest.mark.asyncio +async def test_list_documents_omits_deleted_chats( + client: TestClient, session: Session, documents: List[Document], semi_deleted_chats: List[Chat] +): + app.dependency_overrides[get_current_user] = mock_get_current_user + response = client.get("/documents") + app.dependency_overrides.clear() + + docs = response.json() + assert len(docs[0]["chats"]) == 1 + + +# TODO: Implement message deletion? +# @pytest.mark.asyncio +# def test_delete_message( +# client: TestClient, session: Session, chats: List[Chat] +# ): +# chat = chats[0] +# messages = session.exec(select(ChatMessage).where(ChatMessage.chat_id == chat.id)).all() +# message = messages[0] +# assert message.deleted_at is None + +# app.dependency_overrides[get_current_user] = mock_get_current_user +# response = client.delete( +# f"/documents/{chat.document_id}/chat/{chat.id}/message/{message.id}" +# ) +# app.dependency_overrides.clear() + +# assert response.status_code == HTTPStatus.NO_CONTENT + +# updated_message = session.get(ChatMessage, message.id) +# assert updated_message.deleted_at is not None \ No newline at end of file From 8a8ce10c9af9e94a965406f5d180847fffab1df7 Mon Sep 17 00:00:00 2001 From: Simon Podhajsky Date: Mon, 7 Aug 2023 16:22:39 +0200 Subject: [PATCH 2/2] Fix the intermittently failing test --- backend/app/test/test_document_routes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/app/test/test_document_routes.py b/backend/app/test/test_document_routes.py index 984b191..c3d11bf 100644 --- a/backend/app/test/test_document_routes.py +++ b/backend/app/test/test_document_routes.py @@ -317,12 +317,16 @@ async def test_list_documents_omits_deleted_docs( async def test_list_documents_omits_deleted_chats( client: TestClient, session: Session, documents: List[Document], semi_deleted_chats: List[Chat] ): + assert semi_deleted_chats[1].deleted_at is not None + app.dependency_overrides[get_current_user] = mock_get_current_user response = client.get("/documents") app.dependency_overrides.clear() docs = response.json() - assert len(docs[0]["chats"]) == 1 + doc_with_deleted_chat = [doc for doc in docs if doc["id"] == semi_deleted_chats[1].document_id] + assert len(doc_with_deleted_chat[0]["chats"]) == 1 + assert doc_with_deleted_chat[0]["chats"][0]["deleted_at"] is None # TODO: Implement message deletion?