From 9bb6f38c5a8d4ff3196b117674721845bfb11f6d Mon Sep 17 00:00:00 2001 From: Johann Schleier-Smith Date: Sat, 26 Jul 2025 11:37:28 -0700 Subject: [PATCH 1/2] Port continue-as-new fix for customer service workflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add case-insensitive string matching in FAQ lookup tool - Update init_agents to return tuple with agent map for continue-as-new state management - Implement continue-as-new functionality with proper state serialization - Fix client output to skip duplicate user message on print 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../customer_service/customer_service.py | 17 +++- .../run_customer_service_client.py | 2 +- .../workflows/customer_service_workflow.py | 92 +++++++++++++++---- 3 files changed, 85 insertions(+), 26 deletions(-) diff --git a/openai_agents/customer_service/customer_service.py b/openai_agents/customer_service/customer_service.py index 6a08f4ed..88f6e3cd 100644 --- a/openai_agents/customer_service/customer_service.py +++ b/openai_agents/customer_service/customer_service.py @@ -1,5 +1,7 @@ from __future__ import annotations as _annotations +from typing import Dict, Tuple + from agents import Agent, RunContextWrapper, function_tool, handoff from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX from pydantic import BaseModel @@ -23,19 +25,20 @@ class AirlineAgentContext(BaseModel): description_override="Lookup frequently asked questions.", ) async def faq_lookup_tool(question: str) -> str: - if "bag" in question or "baggage" in question: + question_lower = question.lower() + if "bag" in question_lower or "baggage" in question_lower: return ( "You are allowed to bring one bag on the plane. " "It must be under 50 pounds and 22 inches x 14 inches x 9 inches." ) - elif "seats" in question or "plane" in question: + elif "seats" in question_lower or "plane" in question_lower: return ( "There are 120 seats on the plane. " "There are 22 business class seats and 98 economy seats. " "Exit rows are rows 4 and 16. " "Rows 5-8 are Economy Plus, with extra legroom. " ) - elif "wifi" in question: + elif "wifi" in question_lower: return "We have free wifi on the plane, join Airline-Wifi" return "I'm sorry, I don't know the answer to that question." @@ -74,7 +77,9 @@ async def on_seat_booking_handoff( ### AGENTS -def init_agents() -> Agent[AirlineAgentContext]: +def init_agents() -> Tuple[ + Agent[AirlineAgentContext], Dict[str, Agent[AirlineAgentContext]] +]: """ Initialize the agents for the airline customer service workflow. :return: triage agent @@ -121,7 +126,9 @@ def init_agents() -> Agent[AirlineAgentContext]: faq_agent.handoffs.append(triage_agent) seat_booking_agent.handoffs.append(triage_agent) - return triage_agent + return triage_agent, { + agent.name: agent for agent in [faq_agent, seat_booking_agent, triage_agent] + } class ProcessUserMessageInput(BaseModel): diff --git a/openai_agents/customer_service/run_customer_service_client.py b/openai_agents/customer_service/run_customer_service_client.py index e66419e4..044e0775 100644 --- a/openai_agents/customer_service/run_customer_service_client.py +++ b/openai_agents/customer_service/run_customer_service_client.py @@ -67,7 +67,7 @@ async def main(): CustomerServiceWorkflow.process_user_message, message_input ) history.extend(new_history) - print(*new_history, sep="\n") + print(*new_history[1:], sep="\n") except WorkflowUpdateFailedError: print("** Stale conversation. Reloading...") length = len(history) diff --git a/openai_agents/customer_service/workflows/customer_service_workflow.py b/openai_agents/customer_service/workflows/customer_service_workflow.py index c816d868..a5c20f61 100644 --- a/openai_agents/customer_service/workflows/customer_service_workflow.py +++ b/openai_agents/customer_service/workflows/customer_service_workflow.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from agents import ( - Agent, + HandoffCallItem, HandoffOutputItem, ItemHelpers, MessageOutputItem, @@ -12,6 +12,7 @@ TResponseInputItem, trace, ) +from pydantic import dataclasses from temporalio import workflow from openai_agents.customer_service.customer_service import ( @@ -21,32 +22,77 @@ ) +@dataclasses.dataclass +class CustomerServiceWorkflowState: + printed_history: list[str] + current_agent_name: str + context: AirlineAgentContext + input_items: list[TResponseInputItem] + + @workflow.defn class CustomerServiceWorkflow: @workflow.init - def __init__(self, input_items: list[TResponseInputItem] | None = None): + def __init__( + self, customer_service_state: CustomerServiceWorkflowState | None = None + ): self.run_config = RunConfig() - self.chat_history: list[str] = [] - self.current_agent: Agent[AirlineAgentContext] = init_agents() - self.context = AirlineAgentContext() - self.input_items = [] if input_items is None else input_items + + starting_agent, self.agent_map = init_agents() + self.current_agent = ( + self.agent_map[customer_service_state.current_agent_name] + if customer_service_state + else starting_agent + ) + self.context = ( + customer_service_state.context + if customer_service_state + else AirlineAgentContext() + ) + self.printed_history: list[str] = ( + customer_service_state.printed_history if customer_service_state else [] + ) + self.input_items = ( + customer_service_state.input_items if customer_service_state else [] + ) + self.continue_as_new_suggested = False @workflow.run - async def run(self, input_items: list[TResponseInputItem] | None = None): + async def run( + self, customer_service_state: CustomerServiceWorkflowState | None = None + ): await workflow.wait_condition( - lambda: workflow.info().is_continue_as_new_suggested() + # lambda: workflow.info().is_continue_as_new_suggested() + # and + lambda: self.continue_as_new_suggested and workflow.all_handlers_finished() ) - workflow.continue_as_new(self.input_items) + # Convert input_items to plain dictionaries for serialization + serializable_input_items = [] + for item in self.input_items: + if hasattr(item, "model_dump") and callable(getattr(item, "model_dump")): + # Convert Pydantic objects to dictionaries + serializable_input_items.append(item.model_dump()) # type: ignore + else: + # Already a plain Python object + serializable_input_items.append(item) + workflow.continue_as_new( + CustomerServiceWorkflowState( + printed_history=self.printed_history, + current_agent_name=self.current_agent.name, + context=self.context, + input_items=serializable_input_items, + ) + ) @workflow.query def get_chat_history(self) -> list[str]: - return self.chat_history + return self.printed_history @workflow.update async def process_user_message(self, input: ProcessUserMessageInput) -> list[str]: - length = len(self.chat_history) - self.chat_history.append(f"User: {input.user_input}") + length = len(self.printed_history) + self.printed_history.append(f"User: {input.user_input}") with trace("Customer service", group_id=workflow.info().workflow_id): self.input_items.append({"content": input.user_input, "role": "user"}) result = await Runner.run( @@ -59,27 +105,33 @@ async def process_user_message(self, input: ProcessUserMessageInput) -> list[str for new_item in result.new_items: agent_name = new_item.agent.name if isinstance(new_item, MessageOutputItem): - self.chat_history.append( + self.printed_history.append( f"{agent_name}: {ItemHelpers.text_message_output(new_item)}" ) elif isinstance(new_item, HandoffOutputItem): - self.chat_history.append( + self.printed_history.append( f"Handed off from {new_item.source_agent.name} to {new_item.target_agent.name}" ) + elif isinstance(new_item, HandoffCallItem): + self.printed_history.append( + f"{agent_name}: Handed off to tool {new_item.raw_item.name}" + ) elif isinstance(new_item, ToolCallItem): - self.chat_history.append(f"{agent_name}: Calling a tool") + self.printed_history.append(f"{agent_name}: Calling a tool") elif isinstance(new_item, ToolCallOutputItem): - self.chat_history.append( + self.printed_history.append( f"{agent_name}: Tool call output: {new_item.output}" ) else: - self.chat_history.append( + self.printed_history.append( f"{agent_name}: Skipping item: {new_item.__class__.__name__}" ) self.input_items = result.to_input_list() self.current_agent = result.last_agent - workflow.set_current_details("\n\n".join(self.chat_history)) - return self.chat_history[length:] + workflow.set_current_details("\n\n".join(self.printed_history)) + + self.continue_as_new_suggested = True + return self.printed_history[length:] @process_user_message.validator def validate_process_user_message(self, input: ProcessUserMessageInput) -> None: @@ -87,5 +139,5 @@ def validate_process_user_message(self, input: ProcessUserMessageInput) -> None: raise ValueError("User input cannot be empty.") if len(input.user_input) > 1000: raise ValueError("User input is too long. Please limit to 1000 characters.") - if input.chat_length != len(self.chat_history): + if input.chat_length != len(self.printed_history): raise ValueError("Stale chat history. Please refresh the chat.") From 0a6af6276cdbae3c468b7756f4b474c529eb037f Mon Sep 17 00:00:00 2001 From: Johann Schleier-Smith Date: Sat, 26 Jul 2025 11:44:50 -0700 Subject: [PATCH 2/2] cleanup --- .../workflows/customer_service_workflow.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/openai_agents/customer_service/workflows/customer_service_workflow.py b/openai_agents/customer_service/workflows/customer_service_workflow.py index a5c20f61..0157d050 100644 --- a/openai_agents/customer_service/workflows/customer_service_workflow.py +++ b/openai_agents/customer_service/workflows/customer_service_workflow.py @@ -55,33 +55,21 @@ def __init__( self.input_items = ( customer_service_state.input_items if customer_service_state else [] ) - self.continue_as_new_suggested = False @workflow.run async def run( self, customer_service_state: CustomerServiceWorkflowState | None = None ): await workflow.wait_condition( - # lambda: workflow.info().is_continue_as_new_suggested() - # and - lambda: self.continue_as_new_suggested + lambda: workflow.info().is_continue_as_new_suggested() and workflow.all_handlers_finished() ) - # Convert input_items to plain dictionaries for serialization - serializable_input_items = [] - for item in self.input_items: - if hasattr(item, "model_dump") and callable(getattr(item, "model_dump")): - # Convert Pydantic objects to dictionaries - serializable_input_items.append(item.model_dump()) # type: ignore - else: - # Already a plain Python object - serializable_input_items.append(item) workflow.continue_as_new( CustomerServiceWorkflowState( printed_history=self.printed_history, current_agent_name=self.current_agent.name, context=self.context, - input_items=serializable_input_items, + input_items=self.input_items, ) ) @@ -130,7 +118,6 @@ async def process_user_message(self, input: ProcessUserMessageInput) -> list[str self.current_agent = result.last_agent workflow.set_current_details("\n\n".join(self.printed_history)) - self.continue_as_new_suggested = True return self.printed_history[length:] @process_user_message.validator