11import logging
22import uuid
3+ from typing import Any , Callable , Dict , List , Literal , Optional , TypeVar , Union , cast
34
5+ import httpx
46from typing_extensions import deprecated
5- from typing import (
6- Any ,
7- Callable ,
8- Dict ,
9- List ,
10- Literal ,
11- Optional ,
12- TypeVar ,
13- Union ,
14- cast ,
15- )
167
178from literalai .api .base import BaseLiteralAPI , prepare_variables
18-
199from literalai .api .helpers .attachment_helpers import (
2010 AttachmentUpload ,
2111 create_attachment_helper ,
9181 DatasetExperimentItem ,
9282)
9383from literalai .evaluation .dataset_item import DatasetItem
84+ from literalai .my_types import PaginatedResponse , User
9485from literalai .observability .filter import (
9586 generations_filters ,
9687 generations_order_by ,
10293 threads_order_by ,
10394 users_filters ,
10495)
105- from literalai .observability .thread import Thread
106- from literalai .prompt_engineering .prompt import Prompt , ProviderSettings
107-
108- import httpx
109-
110- from literalai .my_types import PaginatedResponse , User
11196from literalai .observability .generation import (
11297 BaseGeneration ,
11398 ChatGeneration ,
123108 StepDict ,
124109 StepType ,
125110)
111+ from literalai .observability .thread import Thread
112+ from literalai .prompt_engineering .prompt import Prompt , ProviderSettings
126113
127114logger = logging .getLogger (__name__ )
128115
@@ -141,7 +128,11 @@ class AsyncLiteralAPI(BaseLiteralAPI):
141128 R = TypeVar ("R" )
142129
143130 async def make_gql_call (
144- self , description : str , query : str , variables : Dict [str , Any ], timeout : Optional [int ] = 10
131+ self ,
132+ description : str ,
133+ query : str ,
134+ variables : Dict [str , Any ],
135+ timeout : Optional [int ] = 10 ,
145136 ) -> Dict :
146137 def raise_error (error ):
147138 logger .error (f"Failed to { description } : { error } " )
@@ -166,8 +157,7 @@ def raise_error(error):
166157 json = response .json ()
167158 except ValueError as e :
168159 raise_error (
169- f"""Failed to parse JSON response: {
170- e } , content: { response .content !r} """
160+ f"Failed to parse JSON response: { e } , content: { response .content !r} "
171161 )
172162
173163 if json .get ("errors" ):
@@ -178,8 +168,7 @@ def raise_error(error):
178168 for value in json ["data" ].values ():
179169 if value and value .get ("ok" ) is False :
180170 raise_error (
181- f"""Failed to { description } : {
182- value .get ('message' )} """
171+ f"""Failed to { description } : { value .get ("message" )} """
183172 )
184173 return json
185174
@@ -203,9 +192,9 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
203192 return response .json ()
204193 except ValueError as e :
205194 raise ValueError (
206- f"""Failed to parse JSON response: {
207- e } , content: { response .content !r} """
195+ f"Failed to parse JSON response: { e } , content: { response .content !r} "
208196 )
197+
209198 async def gql_helper (
210199 self ,
211200 query : str ,
@@ -235,7 +224,9 @@ async def get_user(
235224 ) -> "User" :
236225 return await self .gql_helper (* get_user_helper (id , identifier ))
237226
238- async def create_user (self , identifier : str , metadata : Optional [Dict ] = None ) -> "User" :
227+ async def create_user (
228+ self , identifier : str , metadata : Optional [Dict ] = None
229+ ) -> "User" :
239230 return await self .gql_helper (* create_user_helper (identifier , metadata ))
240231
241232 async def update_user (
@@ -245,7 +236,7 @@ async def update_user(
245236
246237 async def delete_user (self , id : str ) -> Dict :
247238 return await self .gql_helper (* delete_user_helper (id ))
248-
239+
249240 async def get_or_create_user (
250241 self , identifier : str , metadata : Optional [Dict ] = None
251242 ) -> "User" :
@@ -273,7 +264,7 @@ async def get_threads(
273264 first , after , before , filters , order_by , step_types_to_keep
274265 )
275266 )
276-
267+
277268 async def list_threads (
278269 self ,
279270 first : Optional [int ] = None ,
@@ -491,7 +482,7 @@ async def create_attachment(
491482 thread_id = active_thread .id
492483
493484 if not step_id :
494- if active_steps := active_steps_var .get ([] ):
485+ if active_steps := active_steps_var .get ():
495486 step_id = active_steps [- 1 ].id
496487 else :
497488 raise Exception ("No step_id provided and no active step found." )
@@ -532,7 +523,9 @@ async def create_attachment(
532523 response = await self .make_gql_call (description , query , variables )
533524 return process_response (response )
534525
535- async def update_attachment (self , id : str , update_params : AttachmentUpload ) -> "Attachment" :
526+ async def update_attachment (
527+ self , id : str , update_params : AttachmentUpload
528+ ) -> "Attachment" :
536529 return await self .gql_helper (* update_attachment_helper (id , update_params ))
537530
538531 async def get_attachment (self , id : str ) -> Optional ["Attachment" ]:
@@ -545,7 +538,6 @@ async def delete_attachment(self, id: str) -> Dict:
545538 # Step APIs #
546539 ##################################################################################
547540
548-
549541 async def create_step (
550542 self ,
551543 thread_id : Optional [str ] = None ,
@@ -646,7 +638,7 @@ async def get_generations(
646638 return await self .gql_helper (
647639 * get_generations_helper (first , after , before , filters , order_by )
648640 )
649-
641+
650642 async def create_generation (
651643 self , generation : Union ["ChatGeneration" , "CompletionGeneration" ]
652644 ) -> Union ["ChatGeneration" , "CompletionGeneration" ]:
@@ -667,8 +659,10 @@ async def create_dataset(
667659 return await self .gql_helper (
668660 * create_dataset_helper (sync_api , name , description , metadata , type )
669661 )
670-
671- async def get_dataset (self , id : Optional [str ] = None , name : Optional [str ] = None ) -> "Dataset" :
662+
663+ async def get_dataset (
664+ self , id : Optional [str ] = None , name : Optional [str ] = None
665+ ) -> "Dataset" :
672666 sync_api = LiteralAPI (self .api_key , self .url )
673667 subpath , _ , variables , process_response = get_dataset_helper (
674668 sync_api , id = id , name = name
@@ -738,7 +732,7 @@ async def create_experiment_item(
738732 result .scores = await self .create_scores (experiment_item .scores )
739733
740734 return result
741-
735+
742736 ##################################################################################
743737 # DatasetItem APIs #
744738 ##################################################################################
@@ -753,7 +747,7 @@ async def create_dataset_item(
753747 return await self .gql_helper (
754748 * create_dataset_item_helper (dataset_id , input , expected_output , metadata )
755749 )
756-
750+
757751 async def get_dataset_item (self , id : str ) -> "DatasetItem" :
758752 return await self .gql_helper (* get_dataset_item_helper (id ))
759753
@@ -784,7 +778,9 @@ async def get_or_create_prompt_lineage(
784778 return await self .gql_helper (* create_prompt_lineage_helper (name , description ))
785779
786780 @deprecated ('Please use "get_or_create_prompt_lineage" instead.' )
787- async def create_prompt_lineage (self , name : str , description : Optional [str ] = None ) -> Dict :
781+ async def create_prompt_lineage (
782+ self , name : str , description : Optional [str ] = None
783+ ) -> Dict :
788784 return await self .get_or_create_prompt_lineage (name , description )
789785
790786 async def get_or_create_prompt (
@@ -838,7 +834,14 @@ async def get_prompt(
838834 raise ValueError ("At least the `id` or the `name` must be provided." )
839835
840836 sync_api = LiteralAPI (self .api_key , self .url )
841- get_prompt_query , description , variables , process_response , timeout , cached_prompt = get_prompt_helper (
837+ (
838+ get_prompt_query ,
839+ description ,
840+ variables ,
841+ process_response ,
842+ timeout ,
843+ cached_prompt ,
844+ ) = get_prompt_helper (
842845 api = sync_api , id = id , name = name , version = version , cache = self .cache
843846 )
844847
0 commit comments