Skip to content

Commit 581eaba

Browse files
authored
Merge pull request #581 from superannotateai/nest_async
Nest asyncio fix
2 parents addc7bc + c92f4af commit 581eaba

File tree

3 files changed

+41
-21
lines changed

3 files changed

+41
-21
lines changed

src/superannotate/lib/core/usecases/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
from lib.core.usecases.items import * # noqa: F403 F401
88
from lib.core.usecases.models import * # noqa: F403 F401
99
from lib.core.usecases.projects import * # noqa: F403 F401
10+
11+
import nest_asyncio
12+
nest_asyncio.apply()

src/superannotate/lib/core/usecases/annotations.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import boto3
2727
import jsonschema.validators
2828
import lib.core as constants
29-
import nest_asyncio
3029
from jsonschema import Draft7Validator
3130
from jsonschema import ValidationError
3231
from lib.core.conditions import Condition
@@ -391,7 +390,6 @@ def execute(self):
391390
len(items_to_upload), description="Uploading Annotations"
392391
)
393392
try:
394-
nest_asyncio.apply()
395393
asyncio.run(self.run_workers(items_to_upload))
396394
except Exception:
397395
logger.debug(traceback.format_exc())
@@ -737,7 +735,6 @@ def execute(self):
737735
except KeyError:
738736
missing_annotations.append(name)
739737
try:
740-
nest_asyncio.apply()
741738
asyncio.run(self.run_workers(items_to_upload))
742739
except Exception as e:
743740
logger.debug(e)
@@ -935,7 +932,6 @@ def execute(self):
935932
json.dump(annotation_json, annotation_file)
936933
size = annotation_file.tell()
937934
annotation_file.seek(0)
938-
nest_asyncio.apply()
939935
if size > BIG_FILE_THRESHOLD:
940936
uploaded = asyncio.run(
941937
self._service_provider.annotations.upload_big_annotation(
@@ -1550,7 +1546,6 @@ def execute(self):
15501546
large_items = list(filter(lambda item: item.id in large_item_ids, items))
15511547
small_items = list(filter(lambda item: item.id in small_items_ids, items))
15521548
try:
1553-
nest_asyncio.apply()
15541549
annotations = asyncio.run(self.run_workers(large_items, small_items))
15551550
except Exception as e:
15561551
logger.error(e)
@@ -1735,7 +1730,6 @@ def execute(self):
17351730
).data
17361731
if not folders:
17371732
folders.append(self._folder)
1738-
nest_asyncio.apply()
17391733
for folder in folders:
17401734
if self._item_names:
17411735
items = get_or_raise(

tests/unit/test_async_functions.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77
sa = SAClient()
88

99

10+
class DummyIterator:
11+
def __init__(self, delay, to):
12+
self.delay = delay
13+
self.i = 0
14+
self.to = to
15+
16+
def __aiter__(self):
17+
return self
18+
19+
async def __anext__(self):
20+
i = self.i
21+
if i >= self.to:
22+
raise StopAsyncIteration
23+
self.i += 1
24+
if i:
25+
await asyncio.sleep(self.delay)
26+
return i
27+
28+
1029
class TestAsyncFunctions(TestCase):
1130
PROJECT_NAME = "TestAsync"
1231
PROJECT_DESCRIPTION = "Desc"
@@ -26,31 +45,35 @@ def setUpClass(cls):
2645
def tearDownClass(cls):
2746
sa.delete_project(cls.PROJECT_NAME)
2847

48+
@staticmethod
49+
async def nested():
50+
annotations = sa.get_annotations(TestAsyncFunctions.PROJECT_NAME)
51+
assert len(annotations) == 4
52+
2953
def test_get_annotations_in_running_event_loop(self):
3054
async def _test():
3155
annotations = sa.get_annotations(self.PROJECT_NAME)
3256
assert len(annotations) == 4
3357
asyncio.run(_test())
3458

35-
def test_multiple_get_annotations_in_running_event_loop(self):
36-
# TODO add handling of nested loop
37-
async def nested():
38-
sa.attach_items(self.PROJECT_NAME, self.ATTACH_PAYLOAD)
39-
annotations = sa.get_annotations(self.PROJECT_NAME)
40-
assert len(annotations) == 4
41-
async def create_task_test():
42-
import nest_asyncio
43-
nest_asyncio.apply()
44-
task1 = asyncio.create_task(nested())
45-
task2 = asyncio.create_task(nested())
59+
def test_create_task_get_annotations_in_running_event_loop(self):
60+
async def _test():
61+
task1 = asyncio.create_task(self.nested())
62+
task2 = asyncio.create_task(self.nested())
4663
await task1
4764
await task2
48-
asyncio.run(create_task_test())
65+
asyncio.run(_test())
66+
67+
def test_gather_get_annotations_in_running_event_loop(self):
68+
async def gather_test():
69+
await asyncio.gather(self.nested(), self.nested())
70+
asyncio.run(gather_test())
4971

72+
def test_gather_async_for(self):
5073
async def gather_test():
51-
import nest_asyncio
52-
nest_asyncio.apply()
53-
await asyncio.gather(nested(), nested())
74+
async for _ in DummyIterator(delay=0.01, to=2):
75+
annotations = sa.get_annotations(TestAsyncFunctions.PROJECT_NAME)
76+
assert len(annotations) == 4
5477
asyncio.run(gather_test())
5578

5679
def test_upload_annotations_in_running_event_loop(self):

0 commit comments

Comments
 (0)