Skip to content

Commit 326d731

Browse files
committed
improving citation output prep
1 parent cd976ff commit 326d731

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

willa/chatbot/graph_manager.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Manages the shared state and workflow for Willa chatbots."""
2+
import re
23
from typing import Any, Optional, Annotated, NotRequired
34
from typing_extensions import TypedDict
45

@@ -19,12 +20,10 @@ class WillaChatbotState(TypedDict):
1920
messages: Annotated[list[AnyMessage], add_messages]
2021
filtered_messages: NotRequired[list[AnyMessage]]
2122
summarized_messages: NotRequired[list[AnyMessage]]
22-
docs_context: NotRequired[str]
2323
search_query: NotRequired[str]
2424
tind_metadata: NotRequired[str]
2525
documents: NotRequired[list[Any]]
2626
citations: NotRequired[list[dict[str, Any]]]
27-
context: NotRequired[dict[str, Any]]
2827

2928

3029
class GraphManager: # pylint: disable=too-few-public-methods
@@ -91,25 +90,27 @@ def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
9190
vector_store = self._vector_store
9291

9392
if not search_query or not vector_store:
94-
return {"docs_context": "", "tind_metadata": "", "documents": []}
93+
return {"tind_metadata": "", "documents": []}
9594

9695
# Search for relevant documents
9796
retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
9897
matching_docs = retriever.invoke(search_query)
9998
formatted_documents = [
10099
{
100+
"id": f"{doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0]}_{i}",
101101
"page_content": doc.page_content,
102-
"start_index": str(doc.metadata.get('start_index')) if doc.metadata.get('start_index') else '',
103-
"total_pages": str(doc.metadata.get('total_pages')) if doc.metadata.get('total_pages') else '',
102+
"title": doc.metadata.get('tind_metadata', {}).get('title', [''])[0],
103+
"project": doc.metadata.get('tind_metadata', {}).get('isPartOf', [''])[0],
104+
"tind_link": format_tind_context.get_tind_url(
105+
doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0])
104106
}
105-
for doc in matching_docs
107+
for i, doc in enumerate(matching_docs, 1)
106108
]
107109

108-
# Format context and metadata
109-
docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
110+
# Format tind metadata
110111
tind_metadata = format_tind_context.get_tind_context(matching_docs)
111112

112-
return {"docs_context": docs_context, "tind_metadata": tind_metadata, "documents": formatted_documents}
113+
return {"tind_metadata": tind_metadata, "documents": formatted_documents}
113114

114115
def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
115116
"""Prepare the current and past messages for response generation."""
@@ -154,7 +155,9 @@ def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMess
154155
state['citations'] = citations
155156
response_content += "\n\nCitations:\n"
156157
for citation in citations:
157-
response_content += f"- {citation.get('text', '')} (docs: {citation.get('document_ids', [])})\n"
158+
doc_ids = list(dict.fromkeys([re.sub(r'_\d*$', '', doc_id)
159+
for doc_id in citation.get('document_ids', [])]))
160+
response_content += f"- {citation.get('text', '')} ({', '.join(doc_ids)})\n"
158161

159162
response_messages: list[AnyMessage] = [AIMessage(content=response_content),
160163
ChatMessage(content=tind_metadata, role='TIND',

0 commit comments

Comments
 (0)