@@ -963,7 +963,6 @@ def execute(self):
963963 self .reporter .warning_messages ("Empty scores." )
964964 return self ._response
965965
966-
967966class DownloadAnnotations (BaseReportableUseCase ):
968967 def __init__ (
969968 self ,
@@ -992,6 +991,8 @@ def __init__(
992991 self ._classes = classes
993992 self ._callback = callback
994993 self ._images = images
994+ self ._big_file_queues = []
995+ self ._small_file_queues = []
995996
996997 def validate_item_names (self ):
997998 if self ._item_names :
@@ -1052,7 +1053,115 @@ def coroutine_wrapper(coroutine):
10521053 loop .close ()
10531054 return count
10541055
async def _download_big_annotation(self, item, export_path):
    """Download the annotation of one large item through the backend client.

    Args:
        item: Queue entry describing the item; forwarded unchanged to the
            backend client's ``download_big_annotation``.
        export_path: Base download directory. The name of ``self._folder``
            is appended as a sub-directory unless that folder is the root.

    Returns:
        None. The result of the backend call is not used.
    """
    postfix = self.get_postfix()
    # NOTE(review): the folder id and the sub-directory both come from
    # self._folder, not from the per-folder id handed to run_workers --
    # confirm this is intended for recursive downloads.
    download_path = f"{export_path}{'/' + self._folder.name if not self._folder.is_root else ''}"
    await self._backend_client.download_big_annotation(
        item=item,
        team_id=self._project.team_id,
        project_id=self._project.id,
        folder_id=self._folder.uuid,
        reporter=self.reporter,
        download_path=download_path,
        postfix=postfix,
        callback=self._callback,
    )
1070+
async def download_big_annotations(self, queue_idx, export_path):
    """Worker loop: download large items from one queue until a sentinel.

    Args:
        queue_idx: Index into ``self._big_file_queues`` of the queue to read.
        export_path: Base download directory, forwarded to
            ``_download_big_annotation`` for every item.

    A falsy entry (the producer enqueues ``None``) ends the loop. Before
    exiting, the worker puts ``None`` back so that other workers reading
    the same queue also terminate.
    """
    pending = self._big_file_queues[queue_idx]
    while True:
        entry = await pending.get()
        # The entry is marked done right after retrieval, i.e. before the
        # download itself has run.
        pending.task_done()
        if not entry:
            # Re-queue the sentinel for the sibling workers.
            pending.put_nowait(None)
            return
        await self._download_big_annotation(entry, export_path)
1081+
async def download_small_annotations(self, queue_idx, export_path, max_chunk_size=50000):
    """Worker loop: batch small item names from a queue and download them.

    Item names are read from ``self._small_file_queues[queue_idx]`` and
    accumulated. A batch is sent to the backend client's
    ``download_small_annotations`` each time it reaches ``max_chunk_size``
    names, and once more (with whatever remains, possibly an empty list)
    when the terminating sentinel is read.

    Args:
        queue_idx: Index into ``self._small_file_queues`` of the queue to read.
        export_path: Base download directory. The name of ``self._folder``
            is appended as a sub-directory unless that folder is the root.
        max_chunk_size: Number of item names per backend request. Defaults
            to 50000, the value that was previously hard-coded. Must be
            positive, otherwise no intermediate batch is ever sent.

    Any falsy entry (the producer enqueues ``None``) is treated as the
    sentinel; it is put back on the queue before the loop exits.
    """
    cur_queue = self._small_file_queues[queue_idx]
    postfix = self.get_postfix()
    # NOTE(review): the folder id and the sub-directory both come from
    # self._folder, not from the per-folder id handed to run_workers --
    # confirm this is intended for recursive downloads.
    download_path = f"{export_path}{'/' + self._folder.name if not self._folder.is_root else ''}"

    async def download_chunk(names):
        # Single place for the backend request used by both flush points.
        await self._backend_client.download_small_annotations(
            team_id=self._project.team_id,
            project_id=self._project.id,
            folder_id=self._folder.uuid,
            items=names,
            reporter=self.reporter,
            download_path=download_path,
            postfix=postfix,
            callback=self._callback,
        )

    items = []
    while True:
        item = await cur_queue.get()
        if not item:
            # Sentinel: flush the remainder, then re-queue the sentinel.
            await download_chunk(items)
            await cur_queue.put(None)
            break

        items.append(item)
        if len(items) == max_chunk_size:
            await download_chunk(items)
            # Start a fresh list instead of clearing the one just handed over.
            items = []
1121+
1122+
async def distribute_to_queues(self, item_names, sm_queue_id, l_queue_id, folder_id):
    """Split items by size and feed them to the big/small download queues.

    Args:
        item_names: Names passed to the backend's ``sort_items_by_size``.
        sm_queue_id: Index into ``self._small_file_queues``.
        l_queue_id: Index into ``self._big_file_queues``.
        folder_id: Folder identifier passed to ``sort_items_by_size``.

    Entries of ``resp['large']`` are queued as-is; for ``resp['small']``
    only each entry's ``'name'`` is queued. A ``None`` sentinel is always
    appended to both queues, even when sorting or queuing fails, so the
    consumers waiting on ``queue.get()`` can terminate instead of blocking
    forever.

    Raises:
        Whatever ``sort_items_by_size`` or the response lookups raise; the
        exception propagates after the sentinels have been queued.
    """
    big_queue = self._big_file_queues[l_queue_id]
    small_queue = self._small_file_queues[sm_queue_id]

    try:
        # NOTE(review): this call is made synchronously inside a coroutine;
        # if it performs blocking I/O the download workers are stalled
        # while it runs -- confirm.
        resp = self._backend_client.sort_items_by_size(
            item_names, self._project.team_id, self._project.id, folder_id
        )

        for item in resp["large"]:
            await big_queue.put(item)

        for item in resp["small"]:
            await small_queue.put(item["name"])
    finally:
        # Sentinels must reach the workers on every path.
        await big_queue.put(None)
        await small_queue.put(None)
1139+
async def run_workers(self, item_names, folder_id, export_path):
    """Run the producer and the download workers for one folder.

    A fresh pair of queues is registered in ``self._big_file_queues`` and
    ``self._small_file_queues``; then one ``distribute_to_queues``
    producer, three ``download_big_annotations`` workers and one
    ``download_small_annotations`` worker are awaited together.

    Args:
        item_names: Item names to distribute and download.
        folder_id: Folder identifier handed to ``distribute_to_queues``.
        export_path: Base download directory handed to the workers.

    Errors are never raised to the caller: every exception returned by a
    worker, and any exception raised while setting up, is reported through
    ``self.reporter.log_error``.
    """
    try:
        big_queue = asyncio.Queue()
        small_queue = asyncio.Queue()
        self._big_file_queues.append(big_queue)
        self._small_file_queues.append(small_queue)
        # Resolve indices from the queue objects themselves rather than from
        # ``len(...) - 1``: this coroutine is started from several pool
        # threads, and another thread may append between our append and a
        # length read, which would hand two folders the same queue.
        big_file_queue_idx = self._big_file_queues.index(big_queue)
        small_file_queue_idx = self._small_file_queues.index(small_queue)

        results = await asyncio.gather(
            self.distribute_to_queues(item_names, small_file_queue_idx, big_file_queue_idx, folder_id),
            self.download_big_annotations(big_file_queue_idx, export_path),
            self.download_big_annotations(big_file_queue_idx, export_path),
            self.download_big_annotations(big_file_queue_idx, export_path),
            self.download_small_annotations(small_file_queue_idx, export_path),
            return_exceptions=True,
        )

        # With return_exceptions=True failures come back as values; report
        # them instead of dropping them silently.
        for result in results:
            if isinstance(result, BaseException):
                self.reporter.log_error(f"Error {str(result)}")

    except Exception as e:
        self.reporter.log_error(f"Error {str(e)}")
1158+
1159+
def per_folder_execute(self, item_names, folder_id, export_path):
    """Synchronous entry point that downloads one folder's annotations.

    Builds the ``run_workers`` coroutine for the given folder and drives it
    to completion with ``asyncio.run``, so it can be submitted to a thread
    pool as a plain callable.

    Args:
        item_names: Item names to download.
        folder_id: Identifier of the folder being processed.
        export_path: Base download directory for this folder.
    """
    workers = self.run_workers(item_names, folder_id, export_path)
    asyncio.run(workers)
1162+
10551163 def execute (self ):
1164+
10561165 if self .is_valid ():
10571166 export_path = str (
10581167 self .destination
@@ -1064,6 +1173,7 @@ def execute(self):
10641173 f"Downloading the annotations of the requested items to { export_path } \n This might take a while…"
10651174 )
10661175 self .reporter .start_spinner ()
1176+
10671177 folders = []
10681178 if self ._folder .is_root and self ._recursive :
10691179 folders = self ._folders .get_all (
@@ -1072,6 +1182,7 @@ def execute(self):
10721182 )
10731183 folders .append (self ._folder )
10741184 postfix = self .get_postfix ()
1185+
10751186 import nest_asyncio
10761187 import platform
10771188
@@ -1081,60 +1192,35 @@ def execute(self):
10811192 nest_asyncio .apply ()
10821193
10831194 if not folders :
1084- loop = asyncio .new_event_loop ()
1085- if not self ._item_names :
1086- condition = (
1087- Condition ("team_id" , self ._project .team_id , EQ )
1088- & Condition ("project_id" , self ._project .id , EQ )
1089- & Condition ("folder_id" , self ._folder .uuid , EQ )
1090- )
1091- item_names = [item .name for item in self ._images .get_all (condition )]
1092- else :
1093- item_names = self ._item_names
1094- count = loop .run_until_complete (
1095- self ._backend_client .download_annotations (
1096- team_id = self ._project .team_id ,
1097- project_id = self ._project .id ,
1098- folder_id = self ._folder .uuid ,
1099- items = item_names ,
1100- reporter = self .reporter ,
1101- download_path = f"{ export_path } { '/' + self ._folder .name if not self ._folder .is_root else '' } " ,
1102- postfix = postfix ,
1103- callback = self ._callback ,
1104- )
1105- )
1106- else :
1107- with concurrent .futures .ThreadPoolExecutor (max_workers = 5 ) as executor :
1108- coroutines = []
1109- for folder in folders :
1110- if not self ._item_names :
1111- condition = (
1112- Condition ("team_id" , self ._project .team_id , EQ )
1113- & Condition ("project_id" , self ._project .id , EQ )
1114- & Condition ("folder_id" , folder .uuid , EQ )
1115- )
1116- item_names = [
1117- item .name for item in self ._images .get_all (condition )
1118- ]
1119- else :
1120- item_names = self ._item_names
1121- coroutines .append (
1122- self ._backend_client .download_annotations (
1123- team_id = self ._project .team_id ,
1124- project_id = self ._project .id ,
1125- folder_id = folder .uuid ,
1126- items = item_names ,
1127- reporter = self .reporter ,
1128- download_path = f"{ export_path } { '/' + folder .name if not folder .is_root else '' } " , # noqa
1129- postfix = postfix ,
1130- callback = self ._callback ,
1131- )
1195+ folders .append (self ._folder )
1196+
1197+
1198+ with concurrent .futures .ThreadPoolExecutor (max_workers = 5 ) as executor :
1199+ futures = []
1200+
1201+ for folder in folders :
1202+ if not self ._item_names :
1203+ condition = (
1204+ Condition ("team_id" , self ._project .team_id , EQ )
1205+ & Condition ("project_id" , self ._project .id , EQ )
1206+ & Condition ("folder_id" , folder .uuid , EQ )
11321207 )
1133- count = sum (
1134- [i for i in executor .map (self .coroutine_wrapper , coroutines )]
1135- )
1208+ item_names = [item .name for item in self ._images .get_all (condition )]
1209+ else :
1210+ item_names = self ._item_names
1211+
1212+ new_export_path = export_path
1213+ if folder .name != 'root' :
1214+ new_export_path += f'/{ folder .name } '
1215+ executor .submit (self .per_folder_execute , item_names , folder .uuid , new_export_path )
1216+
1217+
1218+ for future in concurrent .futures .as_completed (futures ):
1219+ print ("asd" )
1220+
11361221
11371222 self .reporter .stop_spinner ()
1223+ count = self .get_items_count (export_path )
11381224 self .reporter .log_info (f"Downloaded annotations for { count } items." )
11391225 self .download_annotation_classes (export_path )
11401226 self ._response .data = os .path .abspath (export_path )
0 commit comments