Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions chatkit/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
ThreadItemAddedEvent,
ThreadItemDoneEvent,
ThreadItemRemovedEvent,
ThreadItemUpdated,
ThreadItemUpdatedEvent,
ThreadMetadata,
ThreadStreamEvent,
URLSource,
Expand Down Expand Up @@ -168,7 +168,7 @@ async def update_workflow_task(self, task: Task, task_index: int) -> None:
# ensure reference is updated in case task is a copy
self.workflow_item.workflow.tasks[task_index] = task
await self.stream(
ThreadItemUpdated(
ThreadItemUpdatedEvent(
item_id=self.workflow_item.id,
update=WorkflowTaskUpdated(
task=task,
Expand All @@ -191,7 +191,7 @@ async def add_workflow_task(self, task: Task) -> None:
await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
else:
await self.stream(
ThreadItemUpdated(
ThreadItemUpdatedEvent(
item_id=self.workflow_item.id,
update=WorkflowTaskAdded(
task=task,
Expand Down Expand Up @@ -448,23 +448,23 @@ def end_workflow(item: WorkflowItem):
if event.part.type == "reasoning_text":
continue
content = _convert_content(event.part)
yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=event.item_id,
update=AssistantMessageContentPartAdded(
content_index=event.content_index,
content=content,
),
)
elif event.type == "response.output_text.delta":
yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=event.item_id,
update=AssistantMessageContentPartTextDelta(
content_index=event.content_index,
delta=event.delta,
),
)
elif event.type == "response.output_text.done":
yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=event.item_id,
update=AssistantMessageContentPartDone(
content_index=event.content_index,
Expand All @@ -485,7 +485,7 @@ def end_workflow(item: WorkflowItem):
item_annotation_count[event.item_id][event.content_index] = (
annotation_index + 1
)
yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=event.item_id,
update=AssistantMessageContentPartAnnotationAdded(
content_index=event.content_index,
Expand Down Expand Up @@ -533,7 +533,7 @@ def end_workflow(item: WorkflowItem):
task=ThoughtTask(content=event.delta),
)
ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=ctx.workflow_item.id,
update=WorkflowTaskAdded(
task=streaming_thought.task,
Expand All @@ -547,7 +547,7 @@ def end_workflow(item: WorkflowItem):
and event.summary_index == streaming_thought.index
):
streaming_thought.task.content += event.delta
yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=ctx.workflow_item.id,
update=WorkflowTaskUpdated(
task=streaming_thought.task,
Expand Down Expand Up @@ -578,7 +578,7 @@ def end_workflow(item: WorkflowItem):
task=task,
task_index=ctx.workflow_item.workflow.tasks.index(task),
)
yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=ctx.workflow_item.id,
update=update,
)
Expand Down
4 changes: 2 additions & 2 deletions chatkit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
ThreadItemDoneEvent,
ThreadItemRemovedEvent,
ThreadItemReplacedEvent,
ThreadItemUpdated,
ThreadItemUpdatedEvent,
ThreadMetadata,
ThreadsAddClientToolOutputReq,
ThreadsAddUserMessageReq,
Expand Down Expand Up @@ -209,7 +209,7 @@ async def stream_widget(
try:
new_state = await widget.__anext__()
for update in diff_widget(last_state, new_state):
yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=item_id,
update=update,
)
Expand Down
6 changes: 5 additions & 1 deletion chatkit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,18 @@ class ThreadItemAddedEvent(BaseModel):
item: ThreadItem


class ThreadItemUpdated(BaseModel):
class ThreadItemUpdatedEvent(BaseModel):
"""Event describing an update to an existing thread item."""

type: Literal["thread.item.updated"] = "thread.item.updated"
item_id: str
update: ThreadItemUpdate


# Type alias for backwards compatibility
ThreadItemUpdated = ThreadItemUpdatedEvent


class ThreadItemDoneEvent(BaseModel):
"""Event emitted when a thread item is marked complete."""

Expand Down
38 changes: 19 additions & 19 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
Thread,
ThreadItemAddedEvent,
ThreadItemDoneEvent,
ThreadItemUpdated,
ThreadItemUpdatedEvent,
ThreadStreamEvent,
URLSource,
UserMessageItem,
Expand Down Expand Up @@ -231,7 +231,7 @@ async def widget_generator():
assert isinstance(events[0].item, WidgetItem)
assert events[0].item.widget == Card(children=[Text(id="text", value="")])

assert isinstance(events[1], ThreadItemUpdated)
assert isinstance(events[1], ThreadItemUpdatedEvent)
assert events[1].update.type == "widget.streaming_text.value_delta"
assert events[1].update.component_id == "text"
assert events[1].update.delta == "Hello, world"
Expand Down Expand Up @@ -271,7 +271,7 @@ async def widget_generator():
assert isinstance(events[0].item, WidgetItem)
assert events[0].item.widget == Card(children=[Text(id="text", value="Hello!")])

assert isinstance(events[1], ThreadItemUpdated)
assert isinstance(events[1], ThreadItemUpdatedEvent)
assert events[1].update.type == "widget.root.updated"
assert events[1].update.widget == Card(
children=[Text(key="other text", value="World!", streaming=False)]
Expand Down Expand Up @@ -788,7 +788,7 @@ async def test_stream_agent_response_maps_events():
sequence_number=0,
),
),
ThreadItemUpdated(
ThreadItemUpdatedEvent(
item_id="123",
update=AssistantMessageContentPartTextDelta(
content_index=0,
Expand All @@ -812,7 +812,7 @@ async def test_stream_agent_response_maps_events():
sequence_number=1,
),
),
ThreadItemUpdated(
ThreadItemUpdatedEvent(
item_id="123",
update=AssistantMessageContentPartAdded(
content_index=1,
Expand All @@ -833,7 +833,7 @@ async def test_stream_agent_response_maps_events():
sequence_number=2,
),
),
ThreadItemUpdated(
ThreadItemUpdatedEvent(
item_id="123",
update=AssistantMessageContentPartDone(
content_index=0,
Expand Down Expand Up @@ -862,7 +862,7 @@ async def test_stream_agent_response_maps_events():
sequence_number=3,
),
),
ThreadItemUpdated(
ThreadItemUpdatedEvent(
item_id="123",
update=AssistantMessageContentPartAnnotationAdded(
content_index=0,
Expand Down Expand Up @@ -949,7 +949,7 @@ def add_annotation_event(annotation, sequence_number):

events = await all_events(stream_agent_response(context, result))
assert events == [
ThreadItemUpdated(
ThreadItemUpdatedEvent(
item_id=item_id,
update=AssistantMessageContentPartAnnotationAdded(
content_index=0,
Expand All @@ -960,7 +960,7 @@ def add_annotation_event(annotation, sequence_number):
),
),
),
ThreadItemUpdated(
ThreadItemUpdatedEvent(
item_id=item_id,
update=AssistantMessageContentPartAnnotationAdded(
content_index=0,
Expand Down Expand Up @@ -1297,8 +1297,8 @@ async def test_workflow_streams_first_thought():
event = await anext(stream)
assert context.workflow_item is not None
assert len(context.workflow_item.workflow.tasks) == 1
assert isinstance(event, ThreadItemUpdated)
assert event == ThreadItemUpdated(
assert isinstance(event, ThreadItemUpdatedEvent)
assert event == ThreadItemUpdatedEvent(
item_id=context.workflow_item.id,
update=WorkflowTaskAdded(
task=ThoughtTask(content="Think"),
Expand All @@ -1310,8 +1310,8 @@ async def test_workflow_streams_first_thought():
event = await anext(stream)
assert context.workflow_item is not None
assert len(context.workflow_item.workflow.tasks) == 1
assert isinstance(event, ThreadItemUpdated)
assert event == ThreadItemUpdated(
assert isinstance(event, ThreadItemUpdatedEvent)
assert event == ThreadItemUpdatedEvent(
item_id=context.workflow_item.id,
update=WorkflowTaskUpdated(
task=ThoughtTask(content="Thinking 1"),
Expand All @@ -1323,8 +1323,8 @@ async def test_workflow_streams_first_thought():
event = await anext(stream)
assert context.workflow_item is not None
assert len(context.workflow_item.workflow.tasks) == 1
assert isinstance(event, ThreadItemUpdated)
assert event == ThreadItemUpdated(
assert isinstance(event, ThreadItemUpdatedEvent)
assert event == ThreadItemUpdatedEvent(
item_id=context.workflow_item.id,
update=WorkflowTaskUpdated(
task=ThoughtTask(content="Thinking 1"),
Expand All @@ -1336,8 +1336,8 @@ async def test_workflow_streams_first_thought():
event = await anext(stream)
assert context.workflow_item is not None
assert len(context.workflow_item.workflow.tasks) == 2
assert isinstance(event, ThreadItemUpdated)
assert event == ThreadItemUpdated(
assert isinstance(event, ThreadItemUpdatedEvent)
assert event == ThreadItemUpdatedEvent(
item_id=context.workflow_item.id,
update=WorkflowTaskAdded(
task=ThoughtTask(content="Thinking 2"),
Expand Down Expand Up @@ -1420,8 +1420,8 @@ async def test_workflow_ends_on_message():
event = await anext(stream)
assert context.workflow_item is not None
assert len(context.workflow_item.workflow.tasks) == 1
assert isinstance(event, ThreadItemUpdated)
assert event == ThreadItemUpdated(
assert isinstance(event, ThreadItemUpdatedEvent)
assert event == ThreadItemUpdatedEvent(
item_id=context.workflow_item.id,
update=WorkflowTaskAdded(
task=ThoughtTask(content="Thinking 1"),
Expand Down
14 changes: 7 additions & 7 deletions tests/test_chatkit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
ThreadItemDoneEvent,
ThreadItemRemovedEvent,
ThreadItemReplacedEvent,
ThreadItemUpdated,
ThreadItemUpdatedEvent,
ThreadListParams,
ThreadMetadata,
ThreadRetryAfterItemParams,
Expand Down Expand Up @@ -756,7 +756,7 @@ async def action(
actions.append((action, sender))
assert sender

yield ThreadItemUpdated(
yield ThreadItemUpdatedEvent(
item_id=sender.id,
update=WidgetRootUpdated(
widget=Card(
Expand Down Expand Up @@ -807,7 +807,7 @@ async def action(

assert len(events) == 1
assert events[0].type == "thread.item.updated"
assert isinstance(events[0], ThreadItemUpdated)
assert isinstance(events[0], ThreadItemUpdatedEvent)
assert events[0].update.type == "widget.root.updated"
assert events[0].update.widget == Card(children=[Text(value="Email sent!")])

Expand Down Expand Up @@ -1090,17 +1090,17 @@ async def widget_generator():
assert isinstance(events[0].item, WidgetItem)
assert events[0].item.widget == Card(children=[Text(id="text", value="")])

assert isinstance(events[1], ThreadItemUpdated)
assert isinstance(events[1], ThreadItemUpdatedEvent)
assert events[1].update.type == "widget.streaming_text.value_delta"
assert events[1].update.component_id == "text"
assert events[1].update.delta == "Hel"

assert isinstance(events[2], ThreadItemUpdated)
assert isinstance(events[2], ThreadItemUpdatedEvent)
assert events[2].update.type == "widget.streaming_text.value_delta"
assert events[2].update.component_id == "text"
assert events[2].update.delta == "lo,"

assert isinstance(events[3], ThreadItemUpdated)
assert isinstance(events[3], ThreadItemUpdatedEvent)
assert events[3].update.type == "widget.streaming_text.value_delta"
assert events[3].update.component_id == "text"
assert events[3].update.delta == " world"
Expand Down Expand Up @@ -1128,7 +1128,7 @@ async def widget_generator():
assert isinstance(events[0].item, WidgetItem)
assert events[0].item.widget == Card(children=[Text(id="text", value="Hello")])

assert isinstance(events[1], ThreadItemUpdated)
assert isinstance(events[1], ThreadItemUpdatedEvent)
assert events[1].update.type == "widget.root.updated"
assert events[1].update.widget == Card(children=[Text(id="text", value="World")])

Expand Down