@@ -963,7 +963,6 @@ def execute(self):
963963 self .reporter .warning_messages ("Empty scores." )
964964 return self ._response
965965
966-
967966class DownloadAnnotations (BaseReportableUseCase ):
968967 def __init__ (
969968 self ,
@@ -992,6 +991,8 @@ def __init__(
992991 self ._classes = classes
993992 self ._callback = callback
994993 self ._images = images
994+ self ._big_file_queues = []
995+ self ._small_file_queues = []
995996
996997 def validate_item_names (self ):
997998 if self ._item_names :
@@ -1055,7 +1056,115 @@ def coroutine_wrapper(coroutine):
10551056 loop .close ()
10561057 return count
10571058
1059+ async def _download_big_annotation (self , item , export_path ):
1060+ postfix = self .get_postfix ()
1061+ response = await self ._backend_client .download_big_annotation (
1062+ item = item ,
1063+ team_id = self ._project .team_id ,
1064+ project_id = self ._project .id ,
1065+ folder_id = self ._folder .uuid ,
1066+ reporter = self .reporter ,
1067+ download_path = f"{ export_path } { '/' + self ._folder .name if not self ._folder .is_root else '' } " ,
1068+ postfix = postfix ,
1069+ callback = self ._callback ,
1070+ )
1071+
1072+ return
1073+
1074+ async def download_big_annotations (self , queue_idx , export_path ):
1075+ while True :
1076+ cur_queue = self ._big_file_queues [queue_idx ]
1077+ item = await cur_queue .get ()
1078+ cur_queue .task_done ()
1079+ if item :
1080+ await self ._download_big_annotation (item , export_path )
1081+ else :
1082+ cur_queue .put_nowait (None )
1083+ break
1084+
1085+ async def download_small_annotations (self , queue_idx , export_path ):
1086+ max_chunk_size = 50000
1087+
1088+ cur_queue = self ._small_file_queues [queue_idx ]
1089+
1090+ items = []
1091+ i = 0
1092+ item = ""
1093+
1094+ postfix = self .get_postfix ()
1095+ while item is not None :
1096+ item = await cur_queue .get ()
1097+ if not item :
1098+ await self ._backend_client .download_small_annotations (
1099+ team_id = self ._project .team_id ,
1100+ project_id = self ._project .id ,
1101+ folder_id = self ._folder .uuid ,
1102+ items = items ,
1103+ reporter = self .reporter ,
1104+ download_path = f"{ export_path } { '/' + self ._folder .name if not self ._folder .is_root else '' } " ,
1105+ postfix = postfix ,
1106+ callback = self ._callback ,
1107+ )
1108+ await cur_queue .put (None )
1109+ break
1110+
1111+ items .append (item )
1112+ if len (items ) == max_chunk_size :
1113+ await self ._backend_client .download_small_annotations (
1114+ team_id = self ._project .team_id ,
1115+ project_id = self ._project .id ,
1116+ folder_id = self ._folder .uuid ,
1117+ items = items ,
1118+ reporter = self .reporter ,
1119+ download_path = f"{ export_path } { '/' + self ._folder .name if not self ._folder .is_root else '' } " ,
1120+ postfix = postfix ,
1121+ callback = self ._callback ,
1122+ )
1123+ items = []
1124+
1125+
1126+ async def distribute_to_queues (self , item_names , sm_queue_id , l_queue_id , folder_id ):
1127+
1128+ team_id = self ._project .team_id
1129+ project_id = self ._project .id
1130+
1131+ resp = self ._backend_client .sort_items_by_size (item_names , team_id , project_id , folder_id )
1132+
1133+ for item in resp ['large' ]:
1134+ await self ._big_file_queues [l_queue_id ].put (item )
1135+
1136+ for item in resp ['small' ]:
1137+ await self ._small_file_queues [sm_queue_id ].put (item ['name' ])
1138+
1139+
1140+ await self ._big_file_queues [l_queue_id ].put (None )
1141+ await self ._small_file_queues [sm_queue_id ].put (None )
1142+
1143+ async def run_workers (self , item_names , folder_id , export_path ):
1144+ try :
1145+ self ._big_file_queues .append (asyncio .Queue ())
1146+ self ._small_file_queues .append (asyncio .Queue ())
1147+ small_file_queue_idx = len (self ._small_file_queues ) - 1
1148+ big_file_queue_idx = len (self ._big_file_queues ) - 1
1149+
1150+ res = await asyncio .gather (
1151+ self .distribute_to_queues (item_names , small_file_queue_idx , big_file_queue_idx , folder_id ),
1152+ self .download_big_annotations (big_file_queue_idx , export_path ),
1153+ self .download_big_annotations (big_file_queue_idx , export_path ),
1154+ self .download_big_annotations (big_file_queue_idx , export_path ),
1155+ self .download_small_annotations (small_file_queue_idx , export_path ),
1156+ return_exceptions = True
1157+ )
1158+
1159+ except Exception as e :
1160+ self .reporter .log_error (f"Error { str (e )} " )
1161+
1162+
1163+ def per_folder_execute (self , item_names , folder_id , export_path ):
1164+ asyncio .run (self .run_workers (item_names , folder_id , export_path ))
1165+
10581166 def execute (self ):
1167+
10591168 if self .is_valid ():
10601169 export_path = str (
10611170 self .destination
@@ -1067,6 +1176,7 @@ def execute(self):
10671176 f"Downloading the annotations of the requested items to { export_path } \n This might take a while…"
10681177 )
10691178 self .reporter .start_spinner ()
1179+
10701180 folders = []
10711181 if self ._folder .is_root and self ._recursive :
10721182 folders = self ._folders .get_all (
@@ -1075,6 +1185,7 @@ def execute(self):
10751185 )
10761186 folders .append (self ._folder )
10771187 postfix = self .get_postfix ()
1188+
10781189 import nest_asyncio
10791190 import platform
10801191
@@ -1084,60 +1195,35 @@ def execute(self):
10841195 nest_asyncio .apply ()
10851196
10861197 if not folders :
1087- loop = asyncio .new_event_loop ()
1088- if not self ._item_names :
1089- condition = (
1090- Condition ("team_id" , self ._project .team_id , EQ )
1091- & Condition ("project_id" , self ._project .id , EQ )
1092- & Condition ("folder_id" , self ._folder .uuid , EQ )
1093- )
1094- item_names = [item .name for item in self ._images .get_all (condition )]
1095- else :
1096- item_names = self ._item_names
1097- count = loop .run_until_complete (
1098- self ._backend_client .download_annotations (
1099- team_id = self ._project .team_id ,
1100- project_id = self ._project .id ,
1101- folder_id = self ._folder .uuid ,
1102- items = item_names ,
1103- reporter = self .reporter ,
1104- download_path = f"{ export_path } { '/' + self ._folder .name if not self ._folder .is_root else '' } " ,
1105- postfix = postfix ,
1106- callback = self ._callback ,
1107- )
1108- )
1109- else :
1110- with concurrent .futures .ThreadPoolExecutor (max_workers = 5 ) as executor :
1111- coroutines = []
1112- for folder in folders :
1113- if not self ._item_names :
1114- condition = (
1115- Condition ("team_id" , self ._project .team_id , EQ )
1116- & Condition ("project_id" , self ._project .id , EQ )
1117- & Condition ("folder_id" , folder .uuid , EQ )
1118- )
1119- item_names = [
1120- item .name for item in self ._images .get_all (condition )
1121- ]
1122- else :
1123- item_names = self ._item_names
1124- coroutines .append (
1125- self ._backend_client .download_annotations (
1126- team_id = self ._project .team_id ,
1127- project_id = self ._project .id ,
1128- folder_id = folder .uuid ,
1129- items = item_names ,
1130- reporter = self .reporter ,
1131- download_path = f"{ export_path } { '/' + folder .name if not folder .is_root else '' } " , # noqa
1132- postfix = postfix ,
1133- callback = self ._callback ,
1134- )
1198+ folders .append (self ._folder )
1199+
1200+
1201+ with concurrent .futures .ThreadPoolExecutor (max_workers = 5 ) as executor :
1202+ futures = []
1203+
1204+ for folder in folders :
1205+ if not self ._item_names :
1206+ condition = (
1207+ Condition ("team_id" , self ._project .team_id , EQ )
1208+ & Condition ("project_id" , self ._project .id , EQ )
1209+ & Condition ("folder_id" , folder .uuid , EQ )
11351210 )
1136- count = sum (
1137- [i for i in executor .map (self .coroutine_wrapper , coroutines )]
1138- )
1211+ item_names = [item .name for item in self ._images .get_all (condition )]
1212+ else :
1213+ item_names = self ._item_names
1214+
1215+ new_export_path = export_path
1216+ if folder .name != 'root' :
1217+ new_export_path += f'/{ folder .name } '
1218+ executor .submit (self .per_folder_execute , item_names , folder .uuid , new_export_path )
1219+
1220+
1221+ for future in concurrent .futures .as_completed (futures ):
1222+ print ("asd" )
1223+
11391224
11401225 self .reporter .stop_spinner ()
1226+ count = self .get_items_count (export_path )
11411227 self .reporter .log_info (f"Downloaded annotations for { count } items." )
11421228 self .download_annotation_classes (export_path )
11431229 self ._response .data = os .path .abspath (export_path )
0 commit comments