 import time
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, Union
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, Generator, Union

 from literalai.instrumentation import MISTRALAI_PROVIDER
 from literalai.requirements import check_all_requirements

 if TYPE_CHECKING:
     from literalai.client import LiteralClient

-from types import GeneratorType
-
 from literalai.context import active_steps_var, active_thread_var
 from literalai.helper import ensure_values_serializable
-from literalai.observability.generation import GenerationMessage, CompletionGeneration, ChatGeneration, GenerationType
+from literalai.observability.generation import (
+    ChatGeneration,
+    CompletionGeneration,
+    GenerationMessage,
+    GenerationType,
+)
 from literalai.wrappers import AfterContext, BeforeContext, wrap_all

-REQUIREMENTS = ["mistralai>=0.2.0"]
+REQUIREMENTS = ["mistralai>=1.0.0"]

 APIS_TO_WRAP = [
     {
-        "module": "mistralai.client",
-        "object": "MistralClient",
-        "method": "chat",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "complete",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": False,
     },
     {
-        "module": "mistralai.client",
-        "object": "MistralClient",
-        "method": "chat_stream",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "stream",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": False,
     },
     {
-        "module": "mistralai.async_client",
-        "object": "MistralAsyncClient",
-        "method": "chat",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "complete_async",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": True,
     },
     {
-        "module": "mistralai.async_client",
-        "object": "MistralAsyncClient",
-        "method": "chat_stream",
+        "module": "mistralai.chat",
+        "object": "Chat",
+        "method": "stream_async",
         "metadata": {
             "type": GenerationType.CHAT,
         },
         "async": True,
     },
     {
-        "module": "mistralai.client",
-        "object": "MistralClient",
-        "method": "completion",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "complete",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": False,
     },
     {
-        "module": "mistralai.client",
-        "object": "MistralClient",
-        "method": "completion_stream",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "stream",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": False,
     },
     {
-        "module": "mistralai.async_client",
-        "object": "MistralAsyncClient",
-        "method": "completion",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "complete_async",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
         "async": True,
     },
     {
-        "module": "mistralai.async_client",
-        "object": "MistralAsyncClient",
-        "method": "completion_stream",
+        "module": "mistralai.fim",
+        "object": "Fim",
+        "method": "stream_async",
         "metadata": {
             "type": GenerationType.COMPLETION,
         },
@@ -239,13 +242,13 @@ async def before(context: BeforeContext, *args, **kwargs):

     return before

-from mistralai.models.chat_completion import DeltaMessage
+from mistralai import DeltaMessage

 def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage):
     if new_delta.tool_calls:
         if "tool_calls" not in message_completion:
             message_completion["tool_calls"] = []
-        delta_tool_call = new_delta.tool_calls[0]
+        delta_tool_call = new_delta.tool_calls[0]  # type: ignore
         delta_function = delta_tool_call.function
         if not delta_function:
             return False
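
The tool-call branch of process_delta concatenates function-argument fragments as they arrive across chunks; a toy sketch of that accumulation (the fragment values are invented):

    # Hypothetical argument fragments spread across streamed deltas:
    fragments = ['{"city": "Par', 'is"}']
    arguments = ""
    for frag in fragments:
        arguments += frag  # process_delta appends each delta's arguments
    assert arguments == '{"city": "Paris"}'
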
@@ -273,9 +276,11 @@ def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage):
     else:
         return False

+from mistralai import models
+
 def streaming_response(
     generation: Union[ChatGeneration, CompletionGeneration],
-    result: GeneratorType,
+    result: Generator[models.CompletionEvent, None, None],
     context: AfterContext,
 ):
     completion = ""
@@ -286,8 +291,8 @@ def streaming_response(
     token_count = 0
     for chunk in result:
         if generation and isinstance(generation, ChatGeneration):
-            if len(chunk.choices) > 0:
-                ok = process_delta(chunk.choices[0].delta, message_completion)
+            if len(chunk.data.choices) > 0:
+                ok = process_delta(chunk.data.choices[0].delta, message_completion)
                 if not ok:
                     yield chunk
                     continue
@@ -298,22 +303,22 @@ def streaming_response(
                 token_count += 1
         elif generation and isinstance(generation, CompletionGeneration):
             if (
-                len(chunk.choices) > 0
-                and chunk.choices[0].message.content is not None
+                len(chunk.data.choices) > 0
+                and chunk.data.choices[0].delta.content is not None
             ):
                 if generation.tt_first_token is None:
                     generation.tt_first_token = (
                         time.time() - context["start"]
                     ) * 1000
                 token_count += 1
-                completion += chunk.choices[0].message.content
+                completion += chunk.data.choices[0].delta.content

         if (
             generation
             and getattr(chunk, "model", None)
-            and generation.model != chunk.model
+            and generation.model != chunk.data.model
         ):
-            generation.model = chunk.model
+            generation.model = chunk.data.model

         yield chunk

@@ -358,7 +363,7 @@ def after(result, context: AfterContext, *args, **kwargs):
             generation.model = model
             if generation.settings:
                 generation.settings["model"] = model
-        if isinstance(result, GeneratorType):
+        if isinstance(result, Generator):
             return streaming_response(generation, result, context)
         else:
             generation.duration = time.time() - context["start"]
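
A quick check of the swap from types.GeneratorType to typing.Generator: the bare typing.Generator aliases collections.abc.Generator, so it still works with isinstance at runtime (a sketch, not from the patch):

    from typing import Generator

    def gen():
        yield 1

    assert isinstance(gen(), Generator)  # bare typing.Generator supports isinstance
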
@@ -387,7 +392,7 @@ def after(result, context: AfterContext, *args, **kwargs):

 async def async_streaming_response(
     generation: Union[ChatGeneration, CompletionGeneration],
-    result: AsyncGenerator,
+    result: AsyncGenerator[models.CompletionEvent, None],
     context: AfterContext,
 ):
     completion = ""
@@ -398,8 +403,8 @@ async def async_streaming_response(
     token_count = 0
     async for chunk in result:
         if generation and isinstance(generation, ChatGeneration):
-            if len(chunk.choices) > 0:
-                ok = process_delta(chunk.choices[0].delta, message_completion)
+            if len(chunk.data.choices) > 0:
+                ok = process_delta(chunk.data.choices[0].delta, message_completion)
                 if not ok:
                     yield chunk
                     continue
@@ -410,22 +415,22 @@ async def async_streaming_response(
                 token_count += 1
         elif generation and isinstance(generation, CompletionGeneration):
             if (
-                len(chunk.choices) > 0
-                and chunk.choices[0].message.content is not None
+                len(chunk.data.choices) > 0
+                and chunk.data.choices[0].delta is not None
             ):
                 if generation.tt_first_token is None:
                     generation.tt_first_token = (
                         time.time() - context["start"]
                     ) * 1000
                 token_count += 1
-                completion += chunk.choices[0].message.content
+                completion += chunk.data.choices[0].delta.content or ""

         if (
             generation
             and getattr(chunk, "model", None)
-            and generation.model != chunk.model
+            and generation.model != chunk.data.model
         ):
-            generation.model = chunk.model
+            generation.model = chunk.data.model

         yield chunk

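
And a hedged usage sketch for the async path (assumes mistralai>=1.0, where stream_async is awaited and the result is then iterated; model name illustrative):

    import asyncio
    from mistralai import Mistral

    async def main():
        client = Mistral(api_key="...")
        stream = await client.chat.stream_async(
            model="mistral-small-latest",
            messages=[{"role": "user", "content": "Hi"}],
        )
        async for event in stream:
            print(event.data.choices[0].delta.content or "", end="")

    asyncio.run(main())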