
Commit 7d19dc6

Merge pull request #42 from TogetherCrew/fix/41-add-lanchain-tool-call

fix: using langchain tool calling!

2 parents 5430add + 1fa6d64

File tree: 4 files changed, +103 −93 lines changed
requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -5,5 +5,7 @@ crewai==0.105.0
 tc-temporal-backend==1.1.4
 transformers[torch]==4.49.0
 nest-asyncio==1.6.0
-openai==1.66.3
+openai==1.93.0
 tc-hivemind-backend==1.4.3
+langchain==0.3.26
+langchain-openai==0.3.27
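After bumping `openai` and adding the two LangChain packages, a quick check that the resolved environment matches the pins can be useful; this is just a standard-library sketch, not part of the commit:

    from importlib.metadata import version

    # Distribution names exactly as they appear in requirements.txt.
    for pkg in ("openai", "langchain", "langchain-openai"):
        print(f"{pkg}=={version(pkg)}")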

tasks/agent.py

Lines changed: 5 additions & 5 deletions
@@ -134,21 +134,21 @@ async def run_hivemind_agent_activity(
         }
     )
 
+    error_fallback_answer = "Looks like things didn't go through. Please give it another go."
     if isinstance(final_answer, str) and "encountered an error" in final_answer.lower():
         logging.error(f"final_answer: {final_answer}")
-        fallback_answer = "Looks like things didn't go through. Please give it another go."
-
+
         # Update step: Error handling
         mongo_persistence.update_workflow_step(
             workflow_id=workflow_id,
             step_name="error_handling",
             step_data={
                 "errorType": "crewai_error",
                 "originalAnswer": final_answer,
-                "fallbackAnswer": fallback_answer,
+                "fallbackAnswer": error_fallback_answer,
             }
         )
-        final_answer = fallback_answer
+        final_answer = error_fallback_answer
 
     if memory and final_answer != "NONE":
         chat = f"User: {payload.query}\nAgent: {final_answer}"

@@ -178,7 +178,7 @@ async def run_hivemind_agent_activity(
             status="completed_no_answer"
         )
 
-    if final_answer == "NONE":
+    if final_answer == "NONE" or final_answer == error_fallback_answer:
         return None
     else:
         return final_answer
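Hoisting the fallback string into `error_fallback_answer` lets the final return treat a fallback reply the same as the `"NONE"` sentinel. A minimal sketch of that control flow (the `finalize` helper is hypothetical, for illustration only):

    # Sketch: both the "NONE" sentinel and the error fallback map to None,
    # so callers never treat a fallback reply as a real answer.
    from typing import Optional

    error_fallback_answer = (
        "Looks like things didn't go through. Please give it another go."
    )

    def finalize(final_answer: str) -> Optional[str]:
        # Hypothetical helper mirroring the end of run_hivemind_agent_activity.
        if final_answer == "NONE" or final_answer == error_fallback_answer:
            return None
        return final_answer

    assert finalize("NONE") is None
    assert finalize(error_fallback_answer) is None
    assert finalize("a real answer") == "a real answer"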

tasks/hivemind/agent.py

Lines changed: 69 additions & 41 deletions
@@ -4,13 +4,15 @@
 from crewai.flow.flow import Flow, listen, start, router
 from crewai.llm import LLM
 from tasks.hivemind.classify_question import ClassifyQuestion
-from tasks.hivemind.query_data_sources import RAGPipelineTool
-from crewai.process import Process
+from tasks.hivemind.query_data_sources import make_rag_tool
 from pydantic import BaseModel
 from crewai.tools import tool
 from openai import OpenAI
 from typing import Optional
 from tasks.mongo_persistence import MongoPersistence
+from langchain_openai import ChatOpenAI
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.agents import AgentExecutor, create_openai_functions_agent
 
 
 class AgenticFlowState(BaseModel):

@@ -154,50 +156,76 @@ def detect_question_type(self) -> str:
 
     @router("rag")
     def do_rag_query(self) -> str:
-        query_data_source_tool = RAGPipelineTool.setup_tools(
-            community_id=self.community_id,
-            enable_answer_skipping=self.enable_answer_skipping,
-            workflow_id=self.workflow_id,
-        )
+        # query_data_source_tool = RAGPipelineTool.setup_tools(
+        #     community_id=self.community_id,
+        #     enable_answer_skipping=self.enable_answer_skipping,
+        #     workflow_id=self.workflow_id,
+        # )
+
+        # q_a_bot_agent = Agent(
+        #     role="Q&A Bot",
+        #     goal=(
+        #         "You decide when to rely on your internal knowledge and when to retrieve real-time data. "
+        #         "For queries that are not specific to community data, answer using your own LLM knowledge. "
+        #         "Your final response must not exceed 250 words."
+        #     ),
+        #     backstory=(
+        #         "You are an intelligent agent capable of giving concise answers to questions."
+        #     ),
+        #     allow_delegation=True,
+        #     llm=LLM(model="gpt-4o-mini-2024-07-18"),
+        # )
+        # rag_task = Task(
+        #     description=(
+        #         "Answer the following query using a maximum of 250 words. "
+        #         "If the query is specific to community data, use the tool to retrieve updated information; "
+        #         f"otherwise, answer using your internal knowledge.\n\nQuery: {self.state.user_query}"
+        #     ),
+        #     expected_output="A clear, well-structured answer under 250 words that directly addresses the query using appropriate information sources",
+        #     agent=q_a_bot_agent,
+        #     tools=[
+        #         query_data_source_tool(result_as_answer=True),
+        #     ],
+        # )
+
+        # crew = Crew(
+        #     agents=[q_a_bot_agent],
+        #     tasks=[rag_task],
+        #     process=Process.hierarchical,
+        #     manager_llm=LLM(model="gpt-4o-mini-2024-07-18"),
+        #     verbose=True,
+        # )
+
+        # crew_output = crew.kickoff()
 
-        q_a_bot_agent = Agent(
-            role="Q&A Bot",
-            goal=(
-                "You decide when to rely on your internal knowledge and when to retrieve real-time data. "
-                "For queries that are not specific to community data, answer using your own LLM knowledge. "
-                "Your final response must not exceed 250 words."
-            ),
-            backstory=(
-                "You are an intelligent agent capable of giving concise answers to questions."
-            ),
-            allow_delegation=True,
-            llm=LLM(model="gpt-4o-mini-2024-07-18"),
-        )
-        rag_task = Task(
-            description=(
-                "Answer the following query using a maximum of 250 words. "
-                "If the query is specific to community data, use the tool to retrieve updated information; "
-                f"otherwise, answer using your internal knowledge.\n\nQuery: {self.state.user_query}"
-            ),
-            expected_output="A clear, well-structured answer under 250 words that directly addresses the query using appropriate information sources",
-            agent=q_a_bot_agent,
-            tools=[
-                query_data_source_tool(result_as_answer=True),
-            ],
-        )
+        # Store the latest crew output and increment retry count
+        # self.state.last_answer = crew_output
+
+        llm = ChatOpenAI(model="gpt-4o-mini-2024-07-18")
+        rag_tool = make_rag_tool(self.enable_answer_skipping, self.community_id, self.workflow_id)
+        tools = [rag_tool]
 
-        crew = Crew(
-            agents=[q_a_bot_agent],
-            tasks=[rag_task],
-            process=Process.hierarchical,
-            manager_llm=LLM(model="gpt-4o-mini-2024-07-18"),
-            verbose=True,
+        SYSTEM_INSTRUCTIONS = """\
+You are a helpful assistant.
+"""
+
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", SYSTEM_INSTRUCTIONS),
+                MessagesPlaceholder("chat_history", optional=True),
+                ("human", "{input}"),
+                MessagesPlaceholder("agent_scratchpad"),
+            ]
         )
+        agent = create_openai_functions_agent(llm, tools, prompt)
 
-        crew_output = crew.kickoff()
+        # Run the agent
+        agent_executor = AgentExecutor(
+            agent=agent, tools=tools, verbose=True, return_intermediate_steps=False
+        )
 
-        # Store the latest crew output and increment retry count
-        self.state.last_answer = crew_output
+        result = agent_executor.invoke({"input": self.state.user_query})
+        self.state.last_answer = result["output"]
         self.state.retry_count += 1
 
         return "stop"
tasks/hivemind/query_data_sources.py

Lines changed: 26 additions & 46 deletions

@@ -1,17 +1,15 @@
 import asyncio
 import os
-from uuid import uuid1
 
 import nest_asyncio
 from dotenv import load_dotenv
-from typing import Type, Optional
+from typing import Optional, Callable
+from langchain.tools import tool
 from tc_temporal_backend.client import TemporalClient
 from tc_temporal_backend.schema.hivemind import HivemindQueryPayload
-from pydantic import BaseModel, Field
 
 nest_asyncio.apply()
 
-from crewai.tools import BaseTool
 
 
 class QueryDataSources:

@@ -64,55 +62,34 @@ def load_hivemind_queue(self) -> str:
         return hivemind_queue
 
 
-class RAGPipelineToolSchema(BaseModel):
-    """Input schema for RAGPipelineTool."""
+def make_rag_tool(enable_answer_skipping: bool, community_id: str, workflow_id: Optional[str] = None) -> Callable:
+    """
+    Make the RAG pipeline tool.
+    Passing the arguments to the tool instead of relying on the LLM to pass them (making the work for the LLM easier).
 
-    query: str = Field(
-        ...,
-        description=(
-            "The input query string provided by the user. The name is case sensitive."
-            "Please provide a value of type string. This parameter is required."
-        ),
-    )
+    Args:
+        enable_answer_skipping (bool): The flag to enable answer skipping.
+        community_id (str): The community ID.
+        workflow_id (Optional[str]): The workflow ID.
 
-
-class RAGPipelineTool(BaseTool):
-    name: str = "RAG pipeline tool"
-    description: str = (
-        "This tool implements a Retrieval-Augmented Generation (RAG) pipeline which "
-        "queries available data sources to provide accurate answers to user queries. "
-    )
-    args_schema: Type[BaseModel] = RAGPipelineToolSchema
-
-    @classmethod
-    def setup_tools(cls, community_id: str, enable_answer_skipping: bool, workflow_id: Optional[str] = None):
-        """
-        Setup the tool with the necessary community identifier, the flag to enable answer skipping,
-        and the workflow ID for tracking.
+    Returns:
+        Callable: The RAG pipeline tool.
+    """
+    @tool(return_direct=True)
+    def get_rag_answer(query: str) -> str:
         """
-        cls.community_id = community_id
-        cls.enable_answer_skipping = enable_answer_skipping
-        cls.workflow_id = workflow_id
-        return cls
+        Get the answer from the RAG pipeline.
 
-    def _run(self, query: str) -> str:
-        """
-        Execute the RAG pipeline by querying the available data sources.
-
-        Parameters
-        ------------
-        query : str
-            The input query string provided by the user.
+        Args:
+            query (str): The input query string provided by the user.
 
-        Returns
-        ----------
-        response : str
-            The response obtained after querying the data sources.
+        Returns:
+            str: The answer to the query.
         """
         query_data_sources = QueryDataSources(
-            community_id=self.community_id,
-            enable_answer_skipping=self.enable_answer_skipping,
-            workflow_id=self.workflow_id,
+            community_id=community_id,
+            enable_answer_skipping=enable_answer_skipping,
+            workflow_id=workflow_id,
        )
         response = asyncio.run(query_data_sources.query(query))
 

@@ -121,3 +98,6 @@ def _run(self, query: str) -> str:
             return "NONE"
         else:
             return response
+
+    # returning the tool function
+    return get_rag_answer
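The class-based `RAGPipelineTool` is replaced by a closure: `make_rag_tool` binds `community_id`, `enable_answer_skipping`, and `workflow_id` at construction time, so the model only ever fills in `query`. A minimal sketch of the pattern, with a stub body in place of the real `QueryDataSources` call and hypothetical example values:

    from typing import Callable, Optional

    from langchain.tools import tool

    def make_rag_tool(
        enable_answer_skipping: bool,
        community_id: str,
        workflow_id: Optional[str] = None,
    ) -> Callable:
        @tool(return_direct=True)
        def get_rag_answer(query: str) -> str:
            """Get the answer from the RAG pipeline."""
            # community_id and workflow_id are captured from the enclosing
            # scope; they are never exposed to the LLM as tool parameters.
            return f"[{community_id}/{workflow_id}] stub answer for: {query}"

        return get_rag_answer

    rag_tool = make_rag_tool(False, "community-123", "wf-1")
    print(rag_tool.invoke({"query": "hello"}))

Keeping the fixed arguments out of the tool schema shrinks the function-calling surface to a single string parameter, which is harder for the model to get wrong than the old `RAGPipelineToolSchema` contract.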
