
Commit e18dae2

Merge pull request #588 from superannotateai/1855_annotations
updated get/download annotations
2 parents: 5714acb + 51a1edb

5 files changed: +68 additions, −111 deletions


src/superannotate/lib/core/serviceproviders.py

Lines changed: 1 addition & 2 deletions
@@ -332,10 +332,9 @@ async def list_small_annotations(
         raise NotImplementedError

     @abstractmethod
-    def sort_items_by_size(
+    def get_upload_chunks(
         self,
         project: entities.ProjectEntity,
-        folder: entities.FolderEntity,
         item_ids: List[int],
     ) -> Dict[str, List]:
         raise NotImplementedError
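The interface change here is just the rename plus dropping the folder argument. A minimal sketch of what the abstract method looks like after this hunk, with the SDK's entity type left as a string annotation since `entities` is not imported in this snippet:

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class BaseAnnotationService(ABC):
    # ...other abstract annotation operations elided...

    @abstractmethod
    def get_upload_chunks(
        self,
        project: "entities.ProjectEntity",  # SDK entity type, kept as a forward reference
        item_ids: List[int],
    ) -> Dict[str, List]:
        # Implementations classify items into "small" and "large" transfer groups.
        raise NotImplementedError
```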

src/superannotate/lib/core/usecases/annotations.py

Lines changed: 52 additions & 84 deletions
@@ -1387,7 +1387,6 @@ def __init__(
         self._item_names = item_names
         self._item_names_provided = True
         self._big_annotations_queue = None
-        self._small_annotations_queue = None

     def validate_project_type(self):
         if self._project.type == constants.ProjectType.PIXEL.value:

@@ -1436,29 +1435,18 @@ async def get_big_annotation(self):
                 break
         return large_annotations

-    async def get_small_annotations(self):
-        small_annotations = []
-        while True:
-            items = await self._small_annotations_queue.get()
-            if items:
-                annotations = (
-                    await self._service_provider.annotations.list_small_annotations(
-                        project=self._project,
-                        folder=self._folder,
-                        item_ids=[i.id for i in items],
-                        reporter=self.reporter,
-                    )
-                )
-                small_annotations.extend(annotations)
-            else:
-                await self._small_annotations_queue.put(None)
-                break
-        return small_annotations
+    async def get_small_annotations(self, item_ids: List[int]):
+        return await self._service_provider.annotations.list_small_annotations(
+            project=self._project,
+            folder=self._folder,
+            item_ids=item_ids,
+            reporter=self.reporter,
+        )

     async def run_workers(
         self,
         big_annotations: List[BaseItemEntity],
-        small_annotations: List[BaseItemEntity],
+        small_annotations: List[List[Dict]],
     ):
         annotations = []
         if big_annotations:

@@ -1481,26 +1469,16 @@ async def run_workers(
                 )
             )
         if small_annotations:
-            self._small_annotations_queue = asyncio.Queue()
-            small_chunks = divide_to_chunks(
-                small_annotations, size=self._config.ANNOTATION_CHUNK_SIZE
-            )
-            for chunk in small_chunks:
-                self._small_annotations_queue.put_nowait(chunk)
-            self._small_annotations_queue.put_nowait(None)
-
-            annotations.extend(
-                list(
-                    itertools.chain.from_iterable(
-                        await asyncio.gather(
-                            *[
-                                self.get_small_annotations()
-                                for _ in range(self._config.MAX_COROUTINE_COUNT)
-                            ]
-                        )
-                    )
+            for chunks in divide_to_chunks(
+                small_annotations, self._config.MAX_COROUTINE_COUNT
+            ):
+                tasks = []
+                for chunk in chunks:
+                    tasks.append(self.get_small_annotations([i["id"] for i in chunk]))
+                annotations.extend(
+                    list(itertools.chain.from_iterable(await asyncio.gather(*tasks)))
                 )
-            )
+
         return list(filter(None, annotations))

     def execute(self):
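The queue-and-worker pattern removed above is replaced with a plain fan-out: the chunks returned by the server are grouped into batches of at most MAX_COROUTINE_COUNT, and each batch is awaited with asyncio.gather. A self-contained sketch of that pattern, under the assumption that divide_to_chunks simply splits a sequence into fixed-size pieces (the helper and the stubbed fetch below are illustrative, not the SDK's own code):

```python
import asyncio
import itertools
from typing import Dict, List

MAX_COROUTINE_COUNT = 8  # assumed config value; the SDK reads it from self._config


def divide_to_chunks(seq, size):
    """Split a sequence into consecutive pieces of at most `size` items."""
    return [seq[i : i + size] for i in range(0, len(seq), size)]


async def fetch_small_annotations(item_ids: List[int]) -> List[Dict]:
    """Stand-in for list_small_annotations: one request per server-built chunk."""
    await asyncio.sleep(0)  # placeholder for the real HTTP round trip
    return [{"metadata": {"id": i}} for i in item_ids]


async def collect(small_chunks: List[List[Dict]]) -> List[Dict]:
    annotations: List[Dict] = []
    # Run at most MAX_COROUTINE_COUNT chunk requests concurrently per batch.
    for batch in divide_to_chunks(small_chunks, MAX_COROUTINE_COUNT):
        tasks = [fetch_small_annotations([i["id"] for i in chunk]) for chunk in batch]
        annotations.extend(itertools.chain.from_iterable(await asyncio.gather(*tasks)))
    return annotations


if __name__ == "__main__":
    chunks = [[{"id": 1}, {"id": 2}], [{"id": 3}]]  # shape of sort_response["small"]
    print(asyncio.run(collect(chunks)))
```

Note that small_annotations here is already a list of server-built chunks, so the concurrency bound applies to requests per batch rather than to long-lived worker coroutines.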
@@ -1523,7 +1501,6 @@ def execute(self):
             items = get_or_raise(self._service_provider.items.list(condition))
         else:
             items = []
-        id_item_map = {i.id: i for i in items}
         if not items:
             logger.info("No annotations to download.")
             self._response.data = []

@@ -1533,18 +1510,20 @@
             f"Getting {items_count} annotations from "
             f"{self._project.name}{f'/{self._folder.name}' if self._folder.name != 'root' else ''}."
         )
+        id_item_map = {i.id: i for i in items}
         self.reporter.start_progress(
             items_count,
             disable=logger.level > logging.INFO or self.reporter.log_enabled,
         )
-
-        sort_response = self._service_provider.annotations.sort_items_by_size(
-            project=self._project, folder=self._folder, item_ids=list(id_item_map)
+        sort_response = self._service_provider.annotations.get_upload_chunks(
+            project=self._project,
+            item_ids=list(id_item_map),
         )
         large_item_ids = set(map(itemgetter("id"), sort_response["large"]))
-        small_items_ids = set(map(itemgetter("id"), sort_response["small"]))
-        large_items = list(filter(lambda item: item.id in large_item_ids, items))
-        small_items = list(filter(lambda item: item.id in small_items_ids, items))
+        large_items: List[BaseItemEntity] = list(
+            filter(lambda item: item.id in large_item_ids, items)
+        )
+        small_items: List[List[dict]] = sort_response["small"]
         try:
             annotations = asyncio.run(self.run_workers(large_items, small_items))
         except Exception as e:
@@ -1580,7 +1559,6 @@ def __init__(
         self._service_provider = service_provider
         self._callback = callback
         self._big_file_queue = None
-        self._small_file_queue = None

     def validate_item_names(self):
         if self._item_names:

@@ -1659,28 +1637,24 @@ async def download_big_annotations(self, export_path):
                 self._big_file_queue.put_nowait(None)
                 break

-    async def download_small_annotations(self, export_path, folder: FolderEntity):
+    async def download_small_annotations(
+        self, item_ids: List[int], export_path, folder: FolderEntity
+    ):
         postfix = self.get_postfix()
-        while True:
-            items = await self._small_file_queue.get()
-            if items:
-                await self._service_provider.annotations.download_small_annotations(
-                    project=self._project,
-                    folder=folder,
-                    item_ids=[i.id for i in items],
-                    reporter=self.reporter,
-                    download_path=f"{export_path}{'/' + self._folder.name if not self._folder.is_root else ''}",
-                    postfix=postfix,
-                    callback=self._callback,
-                )
-            else:
-                self._small_file_queue.put_nowait(None)
-                break
+        await self._service_provider.annotations.download_small_annotations(
+            project=self._project,
+            folder=folder,
+            item_ids=item_ids,
+            reporter=self.reporter,
+            download_path=f"{export_path}{'/' + self._folder.name if not self._folder.is_root else ''}",
+            postfix=postfix,
+            callback=self._callback,
+        )

     async def run_workers(
         self,
         big_annotations: List[BaseItemEntity],
-        small_annotations: List[BaseItemEntity],
+        small_annotations: List[List[dict]],
         folder: FolderEntity,
         export_path,
     ):
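A side note on the new signature: the download path is still derived from the use case's own folder. A quick illustration of how that f-string resolves, with a stubbed folder object (FolderEntity itself is not imported here):

```python
class Folder:
    """Minimal stand-in for FolderEntity, just to show how download_path is built."""

    def __init__(self, name: str, is_root: bool):
        self.name = name
        self.is_root = is_root


def build_download_path(export_path: str, folder: Folder) -> str:
    # Mirrors the expression passed as download_path above.
    return f"{export_path}{'/' + folder.name if not folder.is_root else ''}"


print(build_download_path("/tmp/export", Folder("root", is_root=True)))      # /tmp/export
print(build_download_path("/tmp/export", Folder("batch_1", is_root=False)))  # /tmp/export/batch_1
```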
@@ -1697,19 +1671,17 @@ async def run_workers(
             )

         if small_annotations:
-            self._small_file_queue = asyncio.Queue()
-            small_chunks = divide_to_chunks(
-                small_annotations, size=self._config.ANNOTATION_CHUNK_SIZE
-            )
-            for chunk in small_chunks:
-                self._small_file_queue.put_nowait(chunk)
-            self._small_file_queue.put_nowait(None)
-            await asyncio.gather(
-                *[
-                    self.download_small_annotations(export_path, folder)
-                    for _ in range(self._config.MAX_COROUTINE_COUNT)
-                ]
-            )
+            for chunks in divide_to_chunks(
+                small_annotations, self._config.MAX_COROUTINE_COUNT
+            ):
+                tasks = []
+                for chunk in chunks:
+                    tasks.append(
+                        self.download_small_annotations(
+                            [i["id"] for i in chunk], export_path, folder
+                        )
+                    )
+                await asyncio.gather(*tasks)

     def execute(self):
         if self.is_valid():
@@ -1749,19 +1721,15 @@ def execute(self):
                 new_export_path += f"/{folder.name}"

             id_item_map = {i.id: i for i in items}
-            sort_response = self._service_provider.annotations.sort_items_by_size(
+            sort_response = self._service_provider.annotations.get_upload_chunks(
                 project=self._project,
-                folder=self._folder,
                 item_ids=list(id_item_map),
             )
             large_item_ids = set(map(itemgetter("id"), sort_response["large"]))
-            small_items_ids = set(map(itemgetter("id"), sort_response["small"]))
-            large_items = list(
+            large_items: List[BaseItemEntity] = list(
                 filter(lambda item: item.id in large_item_ids, items)
             )
-            small_items = list(
-                filter(lambda item: item.id in small_items_ids, items)
-            )
+            small_items: List[List[dict]] = sort_response["small"]
             try:
                 asyncio.run(
                     self.run_workers(
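From the SDK surface none of this is visible: the small/large split now happens server-side, but the public calls (for example sa.get_annotations, used in the test below) behave as before. A typical usage sketch, assuming the usual get_annotations/download_annotations entry points and that credentials are already configured; project name and output path are examples:

```python
from superannotate import SAClient

sa = SAClient()  # assumes an SA token / config file is already set up

# Items are classified into small/large groups server-side via get_upload_chunks;
# the public entry points are unchanged.
annotations = sa.get_annotations("Example Project")
export_dir = sa.download_annotations("Example Project", path="./annotations")
print(len(annotations), export_dir)
```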

src/superannotate/lib/infrastructure/services/annotation.py

Lines changed: 14 additions & 23 deletions
@@ -24,7 +24,7 @@


 class AnnotationService(BaseAnnotationService):
-    ASSETS_PROVIDER_VERSION = "v2"
+    ASSETS_PROVIDER_VERSION = "v2.01"
     DEFAULT_CHUNK_SIZE = 5000

     URL_GET_ANNOTATIONS = "items/annotations/download"

@@ -153,33 +153,24 @@ async def list_small_annotations(
             params=query_params,
         )

-    def sort_items_by_size(
+    def get_upload_chunks(
         self,
         project: entities.ProjectEntity,
-        folder: entities.FolderEntity,
         item_ids: List[int],
     ) -> Dict[str, List]:
-        chunk_size = 2000
-        query_params = {
-            "project_id": project.id,
-            "folder_id": folder.id,
-        }
-
         response_data = {"small": [], "large": []}
-        for i in range(0, len(item_ids), chunk_size):
-            body = {
-                "item_ids": item_ids[i : i + chunk_size],  # noqa
-            }  # noqa
-            response = self.client.request(
-                url=urljoin(self.assets_provider_url, self.URL_CLASSIFY_ITEM_SIZE),
-                method="POST",
-                params=query_params,
-                data=body,
-            )
-            if not response.ok:
-                raise AppException(response.error)
-            response_data["small"].extend(response.data.get("small", []))
-            response_data["large"].extend(response.data.get("large", []))
+        response = self.client.request(
+            url=urljoin(self.assets_provider_url, self.URL_CLASSIFY_ITEM_SIZE),
+            method="POST",
+            params={"project_id": project.id, "limit": len(item_ids)},
+            data={"item_ids": item_ids},
+        )
+        if not response.ok:
+            raise AppException(response.error)
+        response_data["small"] = [
+            i["data"] for i in response.data.get("small", {}).values()
+        ]
+        response_data["large"] = response.data.get("large", [])
         return response_data

     async def download_big_annotation(
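The single request now returns items already grouped: judging from the reshaping above, response.data["small"] is a mapping whose values each carry a "data" list of item dicts, while "large" stays a flat list. A small sketch of that reshaping on a made-up payload (the payload shape is inferred from the code, not taken from API docs):

```python
from typing import Dict, List

# Hypothetical response body, shaped to match the code above.
raw = {
    "small": {
        "chunk_0": {"data": [{"id": 1}, {"id": 2}]},
        "chunk_1": {"data": [{"id": 3}]},
    },
    "large": [{"id": 4}],
}

response_data: Dict[str, List] = {"small": [], "large": []}
response_data["small"] = [i["data"] for i in raw.get("small", {}).values()]
response_data["large"] = raw.get("large", [])

print(response_data)
# {'small': [[{'id': 1}, {'id': 2}], [{'id': 3}]], 'large': [{'id': 4}]}
```

Each inner list in "small" then becomes one list_small_annotations / download_small_annotations request in the use cases above.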

tests/integration/annotations/test_get_annotations.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def test_get_annotations10000(self):
             [
                 {"name": f"example_image_{i}.jpg", "url": f"url_{i}"}
                 for i in range(count)
-            ],  # noqa
+            ],
         )
         assert len(sa.search_items(self.PROJECT_NAME)) == count
         a = sa.get_annotations(self.PROJECT_NAME)

tests/unit/test_async_functions.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@

 from superannotate import SAClient

-
 sa = SAClient()
