Skip to content

Commit 796894e

Browse files
committed
feat: add step start and finish events
1 parent 267ca19 commit 796894e

File tree

28 files changed

+1292
-18
lines changed

28 files changed

+1292
-18
lines changed

src/Enums/StreamEventType.php

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,6 @@ enum StreamEventType: string
1919
case Citation = 'citation';
2020
case Error = 'error';
2121
case StreamEnd = 'stream_end';
22+
case StepStart = 'step_start';
23+
case StepFinish = 'step_finish';
2224
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Prism\Prism\Events\Broadcasting;
6+
7+
class StepFinishBroadcast extends StreamEventBroadcast {}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Prism\Prism\Events\Broadcasting;
6+
7+
class StepStartBroadcast extends StreamEventBroadcast {}

src/Providers/Anthropic/Handlers/Stream.php

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
use Prism\Prism\Streaming\Events\CitationEvent;
1818
use Prism\Prism\Streaming\Events\ErrorEvent;
1919
use Prism\Prism\Streaming\Events\ProviderToolEvent;
20+
use Prism\Prism\Streaming\Events\StepFinishEvent;
21+
use Prism\Prism\Streaming\Events\StepStartEvent;
2022
use Prism\Prism\Streaming\Events\StreamEndEvent;
2123
use Prism\Prism\Streaming\Events\StreamEvent;
2224
use Prism\Prism\Streaming\Events\StreamStartEvent;
@@ -76,15 +78,21 @@ protected function processStream(Response $response, Request $request, int $dept
7678
$streamEvent = $this->processEvent($event);
7779

7880
if ($streamEvent instanceof Generator) {
79-
yield from $streamEvent;
81+
// Re-yield items to avoid key conflicts when using collect()
82+
foreach ($streamEvent as $item) {
83+
yield $item;
84+
}
8085
} elseif ($streamEvent instanceof StreamEvent) {
8186
yield $streamEvent;
8287
}
8388
}
8489

8590
// Handle tool calls if present
8691
if ($this->state->hasToolCalls()) {
87-
yield from $this->handleToolCalls($request, $depth);
92+
// Re-yield items to avoid key conflicts when using collect()
93+
foreach ($this->handleToolCalls($request, $depth) as $item) {
94+
yield $item;
95+
}
8896
}
8997
}
9098

@@ -109,8 +117,9 @@ protected function processEvent(array $event): StreamEvent|Generator|null
109117

110118
/**
111119
* @param array<string, mixed> $event
120+
* @return Generator<StreamEvent>
112121
*/
113-
protected function handleMessageStart(array $event): StreamStartEvent
122+
protected function handleMessageStart(array $event): Generator
114123
{
115124
$message = $event['message'] ?? [];
116125
$this->state->withMessageId($message['id'] ?? EventID::generate());
@@ -126,12 +135,21 @@ protected function handleMessageStart(array $event): StreamStartEvent
126135
));
127136
}
128137

129-
return new StreamStartEvent(
138+
yield new StreamStartEvent(
130139
id: EventID::generate(),
131140
timestamp: time(),
132141
model: $message['model'] ?? 'unknown',
133142
provider: 'anthropic'
134143
);
144+
145+
if ($this->state->shouldEmitStepStart()) {
146+
$this->state->markStepStarted();
147+
148+
yield new StepStartEvent(
149+
id: EventID::generate(),
150+
timestamp: time()
151+
);
152+
}
135153
}
136154

137155
/**
@@ -228,10 +246,17 @@ protected function handleMessageDelta(array $event): null
228246

229247
/**
230248
* @param array<string, mixed> $event
249+
* @return Generator<StreamEvent>
231250
*/
232-
protected function handleMessageStop(array $event): StreamEndEvent
251+
protected function handleMessageStop(array $event): Generator
233252
{
234-
return new StreamEndEvent(
253+
$this->state->markStepFinished();
254+
yield new StepFinishEvent(
255+
id: EventID::generate(),
256+
timestamp: time()
257+
);
258+
259+
yield new StreamEndEvent(
235260
id: EventID::generate(),
236261
timestamp: time(),
237262
finishReason: FinishReason::Stop, // Default, will be updated by message_delta
@@ -492,6 +517,13 @@ protected function handleToolCalls(Request $request, int $depth): Generator
492517

493518
$request->addMessage(new ToolResultMessage($toolResults));
494519

520+
// Emit step finish after tool calls
521+
$this->state->markStepFinished();
522+
yield new StepFinishEvent(
523+
id: EventID::generate(),
524+
timestamp: time()
525+
);
526+
495527
// Continue streaming if within step limit
496528
$depth++;
497529
if ($depth < $request->maxSteps()) {

src/Providers/DeepSeek/Handlers/Stream.php

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
use Prism\Prism\Providers\DeepSeek\Maps\ToolChoiceMap;
2121
use Prism\Prism\Providers\DeepSeek\Maps\ToolMap;
2222
use Prism\Prism\Streaming\EventID;
23+
use Prism\Prism\Streaming\Events\StepFinishEvent;
24+
use Prism\Prism\Streaming\Events\StepStartEvent;
2325
use Prism\Prism\Streaming\Events\StreamEndEvent;
2426
use Prism\Prism\Streaming\Events\StreamEvent;
2527
use Prism\Prism\Streaming\Events\StreamStartEvent;
@@ -95,6 +97,15 @@ protected function processStream(Response $response, Request $request, int $dept
9597
);
9698
}
9799

100+
if ($this->state->shouldEmitStepStart()) {
101+
$this->state->markStepStarted();
102+
103+
yield new StepStartEvent(
104+
id: EventID::generate(),
105+
timestamp: time()
106+
);
107+
}
108+
98109
if ($this->hasToolCalls($data)) {
99110
$toolCalls = $this->extractToolCalls($data, $toolCalls);
100111

@@ -200,6 +211,12 @@ protected function processStream(Response $response, Request $request, int $dept
200211

201212
$usage = $this->extractUsage($data);
202213

214+
$this->state->markStepFinished();
215+
yield new StepFinishEvent(
216+
id: EventID::generate(),
217+
timestamp: time()
218+
);
219+
203220
yield new StreamEndEvent(
204221
id: EventID::generate(),
205222
timestamp: time(),
@@ -362,6 +379,12 @@ protected function handleToolCalls(Request $request, string $text, array $toolCa
362379
$request->addMessage(new AssistantMessage($text, $mappedToolCalls));
363380
$request->addMessage(new ToolResultMessage($toolResults));
364381

382+
$this->state->markStepFinished();
383+
yield new StepFinishEvent(
384+
id: EventID::generate(),
385+
timestamp: time()
386+
);
387+
365388
$this->state->resetTextState();
366389
$this->state->withMessageId(EventID::generate());
367390

src/Providers/Gemini/Handlers/Stream.php

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
use Prism\Prism\Providers\Gemini\Maps\ToolChoiceMap;
1818
use Prism\Prism\Providers\Gemini\Maps\ToolMap;
1919
use Prism\Prism\Streaming\EventID;
20+
use Prism\Prism\Streaming\Events\StepFinishEvent;
21+
use Prism\Prism\Streaming\Events\StepStartEvent;
2022
use Prism\Prism\Streaming\Events\StreamEndEvent;
2123
use Prism\Prism\Streaming\Events\StreamEvent;
2224
use Prism\Prism\Streaming\Events\StreamStartEvent;
@@ -98,6 +100,16 @@ protected function processStream(Response $response, Request $request, int $dept
98100
$this->state->markStreamStarted();
99101
}
100102

103+
// Emit step start event once per step
104+
if ($this->state->shouldEmitStepStart()) {
105+
$this->state->markStepStarted();
106+
107+
yield new StepStartEvent(
108+
id: EventID::generate(),
109+
timestamp: time()
110+
);
111+
}
112+
101113
// Update usage data from each chunk
102114
$this->state->withUsage($this->extractUsage($data, $request));
103115

@@ -208,6 +220,13 @@ protected function processStream(Response $response, Request $request, int $dept
208220
// Extract grounding metadata if available
209221
$groundingMetadata = $this->extractGroundingMetadata($data);
210222

223+
// Emit step finish before stream end
224+
$this->state->markStepFinished();
225+
yield new StepFinishEvent(
226+
id: EventID::generate(),
227+
timestamp: time()
228+
);
229+
211230
// Emit stream end event
212231
yield new StreamEndEvent(
213232
id: EventID::generate(),
@@ -338,6 +357,13 @@ protected function handleToolCalls(
338357
$request->addMessage(new AssistantMessage($this->state->currentText(), $mappedToolCalls));
339358
$request->addMessage(new ToolResultMessage($toolResults));
340359

360+
// Emit step finish after tool calls
361+
$this->state->markStepFinished();
362+
yield new StepFinishEvent(
363+
id: EventID::generate(),
364+
timestamp: time()
365+
);
366+
341367
$depth++;
342368
if ($depth < $request->maxSteps()) {
343369
$this->state->reset();

src/Providers/Groq/Handlers/Stream.php

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
use Prism\Prism\Providers\Groq\Maps\ToolMap;
2222
use Prism\Prism\Streaming\EventID;
2323
use Prism\Prism\Streaming\Events\ErrorEvent;
24+
use Prism\Prism\Streaming\Events\StepFinishEvent;
25+
use Prism\Prism\Streaming\Events\StepStartEvent;
2426
use Prism\Prism\Streaming\Events\StreamEndEvent;
2527
use Prism\Prism\Streaming\Events\StreamEvent;
2628
use Prism\Prism\Streaming\Events\StreamStartEvent;
@@ -95,6 +97,16 @@ protected function processStream(Response $response, Request $request, int $dept
9597
);
9698
}
9799

100+
// Emit step start event once per step
101+
if ($this->state->shouldEmitStepStart()) {
102+
$this->state->markStepStarted();
103+
104+
yield new StepStartEvent(
105+
id: EventID::generate(),
106+
timestamp: time()
107+
);
108+
}
109+
98110
if ($this->hasError($data)) {
99111
yield from $this->handleErrors($data, $request);
100112

@@ -162,6 +174,13 @@ protected function processStream(Response $response, Request $request, int $dept
162174
// Extract usage information from the final chunk
163175
$usage = $this->extractUsage($data);
164176

177+
// Emit step finish before stream end
178+
$this->state->markStepFinished();
179+
yield new StepFinishEvent(
180+
id: EventID::generate(),
181+
timestamp: time()
182+
);
183+
165184
yield new StreamEndEvent(
166185
id: EventID::generate(),
167186
timestamp: time(),
@@ -260,6 +279,13 @@ protected function handleToolCalls(
260279
$request->addMessage(new AssistantMessage($text, $mappedToolCalls));
261280
$request->addMessage(new ToolResultMessage($toolResults));
262281

282+
// Emit step finish after tool calls
283+
$this->state->markStepFinished();
284+
yield new StepFinishEvent(
285+
id: EventID::generate(),
286+
timestamp: time()
287+
);
288+
263289
// Reset text state for next response
264290
$this->state->resetTextState();
265291
$this->state->withMessageId(EventID::generate());

src/Providers/Mistral/Handlers/Stream.php

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
use Prism\Prism\Providers\Mistral\Maps\ToolChoiceMap;
2121
use Prism\Prism\Providers\Mistral\Maps\ToolMap;
2222
use Prism\Prism\Streaming\EventID;
23+
use Prism\Prism\Streaming\Events\StepFinishEvent;
24+
use Prism\Prism\Streaming\Events\StepStartEvent;
2325
use Prism\Prism\Streaming\Events\StreamEndEvent;
2426
use Prism\Prism\Streaming\Events\StreamEvent;
2527
use Prism\Prism\Streaming\Events\StreamStartEvent;
@@ -94,6 +96,15 @@ protected function processStream(Response $response, Request $request, int $dept
9496
);
9597
}
9698

99+
if ($this->state->shouldEmitStepStart()) {
100+
$this->state->markStepStarted();
101+
102+
yield new StepStartEvent(
103+
id: EventID::generate(),
104+
timestamp: time()
105+
);
106+
}
107+
97108
if ($this->hasToolCalls($data)) {
98109
$toolCalls = $this->extractToolCalls($data, $toolCalls);
99110

@@ -165,6 +176,12 @@ protected function processStream(Response $response, Request $request, int $dept
165176

166177
$usage = $this->extractUsage($data);
167178

179+
$this->state->markStepFinished();
180+
yield new StepFinishEvent(
181+
id: EventID::generate(),
182+
timestamp: time()
183+
);
184+
168185
yield new StreamEndEvent(
169186
id: EventID::generate(),
170187
timestamp: time(),
@@ -254,6 +271,12 @@ protected function handleToolCalls(
254271
$request->addMessage(new AssistantMessage($text, $mappedToolCalls));
255272
$request->addMessage(new ToolResultMessage($toolResults));
256273

274+
$this->state->markStepFinished();
275+
yield new StepFinishEvent(
276+
id: EventID::generate(),
277+
timestamp: time()
278+
);
279+
257280
$this->state->resetTextState();
258281
$this->state->withMessageId(EventID::generate());
259282

src/Providers/Ollama/Handlers/Stream.php

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
use Prism\Prism\Providers\Ollama\Maps\ToolMap;
1818
use Prism\Prism\Providers\Ollama\ValueObjects\OllamaStreamState;
1919
use Prism\Prism\Streaming\EventID;
20+
use Prism\Prism\Streaming\Events\StepFinishEvent;
21+
use Prism\Prism\Streaming\Events\StepStartEvent;
2022
use Prism\Prism\Streaming\Events\StreamEndEvent;
2123
use Prism\Prism\Streaming\Events\StreamEvent;
2224
use Prism\Prism\Streaming\Events\StreamStartEvent;
@@ -87,6 +89,16 @@ protected function processStream(Response $response, Request $request, int $dept
8789
$this->state->markStreamStarted()->withMessageId(EventID::generate());
8890
}
8991

92+
// Emit step start event once per step
93+
if ($this->state->shouldEmitStepStart()) {
94+
$this->state->markStepStarted();
95+
96+
yield new StepStartEvent(
97+
id: EventID::generate(),
98+
timestamp: time()
99+
);
100+
}
101+
90102
// Accumulate token counts
91103
$this->state->addPromptTokens((int) data_get($data, 'prompt_eval_count', 0));
92104
$this->state->addCompletionTokens((int) data_get($data, 'eval_count', 0));
@@ -182,6 +194,13 @@ protected function processStream(Response $response, Request $request, int $dept
182194
);
183195
}
184196

197+
// Emit step finish before stream end
198+
$this->state->markStepFinished();
199+
yield new StepFinishEvent(
200+
id: EventID::generate(),
201+
timestamp: time()
202+
);
203+
185204
// Emit stream end event with usage
186205
yield new StreamEndEvent(
187206
id: EventID::generate(),
@@ -277,6 +296,13 @@ protected function handleToolCalls(
277296
$request->addMessage(new AssistantMessage($text, $mappedToolCalls));
278297
$request->addMessage(new ToolResultMessage($toolResults));
279298

299+
// Emit step finish after tool calls
300+
$this->state->markStepFinished();
301+
yield new StepFinishEvent(
302+
id: EventID::generate(),
303+
timestamp: time()
304+
);
305+
280306
// Continue streaming if within step limit
281307
$depth++;
282308
if ($depth < $request->maxSteps()) {

0 commit comments

Comments
 (0)