diff --git a/chatkit/actions.py b/chatkit/actions.py index 689a8f1..9357c95 100644 --- a/chatkit/actions.py +++ b/chatkit/actions.py @@ -24,7 +24,7 @@ class ActionConfig(BaseModel): class Action(BaseModel, Generic[TType, TPayload]): type: TType = Field(default=TType, frozen=True) # pyright: ignore - payload: TPayload + payload: TPayload | None = None @classmethod def create( diff --git a/tests/helpers/mock_widget.py b/tests/helpers/mock_widget.py index 2dde244..bd0dccd 100644 --- a/tests/helpers/mock_widget.py +++ b/tests/helpers/mock_widget.py @@ -615,6 +615,7 @@ async def handle_action( generate: Callable[[], AsyncIterator[ThreadStreamEvent]], save: Callable[[], AsyncIterator[ThreadStreamEvent]], ) -> AsyncIterator[ActionOutput]: + assert action.payload is not None, "Action payload is required" if action.type == "sample.show_widget": next_state = Index( selected=action.payload.widget,