Skip to content

Commit 2c071df

Browse files
authored
Merge branch 'friday' into 1855_annotations
2 parents 19cc32e + 5714acb commit 2c071df

File tree

7 files changed

+131
-35
lines changed

7 files changed

+131
-35
lines changed

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1400,7 +1400,8 @@ def create_annotation_class(
14001400
:type color: str
14011401
14021402
:param attribute_groups: list of attribute group dicts.
1403-
The values for the "group_type" key are "radio"|"checklist"|"text"|"numeric".
1403+
The values for the "group_type" key are "radio"|"checklist"|"text"|"numeric"|"ocr".
1404+
"ocr "group_type" key is only available for Vector projects.
14041405
Mandatory keys for each attribute group are
14051406
14061407
- "name"

src/superannotate/lib/core/usecases/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import nest_asyncio
12
from lib.core.usecases.annotations import * # noqa: F403 F401
23
from lib.core.usecases.classes import * # noqa: F403 F401
34
from lib.core.usecases.custom_fields import * # noqa: F403 F401
@@ -7,3 +8,5 @@
78
from lib.core.usecases.items import * # noqa: F403 F401
89
from lib.core.usecases.models import * # noqa: F403 F401
910
from lib.core.usecases.projects import * # noqa: F403 F401
11+
12+
nest_asyncio.apply()

src/superannotate/lib/core/usecases/annotations.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import boto3
2727
import jsonschema.validators
2828
import lib.core as constants
29-
import nest_asyncio
3029
from jsonschema import Draft7Validator
3130
from jsonschema import ValidationError
3231
from lib.core.conditions import Condition
@@ -391,7 +390,6 @@ def execute(self):
391390
len(items_to_upload), description="Uploading Annotations"
392391
)
393392
try:
394-
nest_asyncio.apply()
395393
asyncio.run(self.run_workers(items_to_upload))
396394
except Exception:
397395
logger.debug(traceback.format_exc())
@@ -737,7 +735,6 @@ def execute(self):
737735
except KeyError:
738736
missing_annotations.append(name)
739737
try:
740-
nest_asyncio.apply()
741738
asyncio.run(self.run_workers(items_to_upload))
742739
except Exception as e:
743740
logger.debug(e)
@@ -935,7 +932,6 @@ def execute(self):
935932
json.dump(annotation_json, annotation_file)
936933
size = annotation_file.tell()
937934
annotation_file.seek(0)
938-
nest_asyncio.apply()
939935
if size > BIG_FILE_THRESHOLD:
940936
uploaded = asyncio.run(
941937
self._service_provider.annotations.upload_big_annotation(
@@ -1529,7 +1525,6 @@ def execute(self):
15291525
)
15301526
small_items: List[List[dict]] = sort_response["small"]
15311527
try:
1532-
nest_asyncio.apply()
15331528
annotations = asyncio.run(self.run_workers(large_items, small_items))
15341529
except Exception as e:
15351530
logger.error(e)
@@ -1707,7 +1702,6 @@ def execute(self):
17071702
).data
17081703
if not folders:
17091704
folders.append(self._folder)
1710-
nest_asyncio.apply()
17111705
for folder in folders:
17121706
if self._item_names:
17131707
items = get_or_raise(

src/superannotate/lib/core/usecases/classes.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from lib.core.conditions import CONDITION_EQ as EQ
77
from lib.core.entities import AnnotationClassEntity
88
from lib.core.entities import ProjectEntity
9+
from lib.core.entities.classes import GroupTypeEnum
910
from lib.core.enums import ProjectType
1011
from lib.core.exceptions import AppException
1112
from lib.core.serviceproviders import BaseServiceProvider
@@ -66,6 +67,13 @@ def validate_project_type(self):
6667
"Predefined tagging functionality is not supported for projects"
6768
f" of type {ProjectType.get_name(self._project.type)}."
6869
)
70+
if self._project.type != ProjectType.VECTOR:
71+
for g in self._annotation_class.attribute_groups:
72+
if g.group_type == GroupTypeEnum.OCR:
73+
raise AppException(
74+
f"OCR attribute group is not supported for project type "
75+
f"{ProjectType.get_name(self._project.type)}."
76+
)
6977

7078
def validate_default_value(self):
7179
if self._project.type == ProjectType.PIXEL.value and any(
@@ -109,13 +117,19 @@ def __init__(
109117
self._annotation_classes = annotation_classes
110118

111119
def validate_project_type(self):
112-
if self._project.type == ProjectType.PIXEL and any(
113-
[True for i in self._annotation_classes if i.type == "tag"]
114-
):
115-
raise AppException(
116-
f"Predefined tagging functionality is not supported"
117-
f" for projects of type {ProjectType.get_name(self._project.type)}."
118-
)
120+
if self._project.type != ProjectType.VECTOR:
121+
for c in self._annotation_classes:
122+
if self._project.type == ProjectType.PIXEL and c.type == "tag":
123+
raise AppException(
124+
f"Predefined tagging functionality is not supported"
125+
f" for projects of type {ProjectType.get_name(self._project.type)}."
126+
)
127+
for g in c.attribute_groups:
128+
if g.group_type == GroupTypeEnum.OCR:
129+
raise AppException(
130+
f"OCR attribute group is not supported for project type "
131+
f"{ProjectType.get_name(self._project.type)}."
132+
)
119133

120134
def validate_default_value(self):
121135
if self._project.type == ProjectType.PIXEL.value:

tests/integration/classes/test_classes_serialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,6 @@ def test_group_type_wrong_arg(self):
7878
"'radio',",
7979
"'checklist',",
8080
"'numeric',",
81-
"'text'",
81+
"'text',",
8282
"'ocr'",
8383
] == wrap_error(e).split()

tests/integration/classes/test_create_annotation_class.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,71 @@ def test_create_annotation_class(self):
245245
"Predefined tagging functionality is not supported for projects of type Video.",
246246
)
247247

248+
def test_create_annotation_class_via_ocr_group_type(self):
249+
with self.assertRaisesRegexp(
250+
AppException,
251+
f"OCR attribute group is not supported for project type {self.PROJECT_TYPE}.",
252+
):
253+
attribute_groups = [
254+
{
255+
"id": 21448,
256+
"class_id": 56820,
257+
"name": "Large",
258+
"group_type": "ocr",
259+
"is_multiselect": 0,
260+
"createdAt": "2020-09-29T10:39:39.000Z",
261+
"updatedAt": "2020-09-29T10:39:39.000Z",
262+
"attributes": [],
263+
}
264+
]
265+
sa.create_annotation_class(
266+
self.PROJECT_NAME,
267+
"test_add",
268+
"#FF0000",
269+
attribute_groups,
270+
class_type="tag",
271+
)
272+
273+
def test_create_annotation_class_via_json_and_ocr_group_type(self):
274+
with tempfile.TemporaryDirectory() as tmpdir_name:
275+
temp_path = f"{tmpdir_name}/new_classes.json"
276+
with open(temp_path, "w") as new_classes:
277+
new_classes.write(
278+
"""
279+
[
280+
{
281+
"id":56820,
282+
"project_id":7617,
283+
"name":"Personal vehicle",
284+
"color":"#547497",
285+
"count":18,
286+
"createdAt":"2020-09-29T10:39:39.000Z",
287+
"updatedAt":"2020-09-29T10:48:18.000Z",
288+
"type": "tag",
289+
"attribute_groups":[
290+
{
291+
"id":21448,
292+
"class_id":56820,
293+
"name":"Large",
294+
"group_type": "ocr",
295+
"is_multiselect":0,
296+
"createdAt":"2020-09-29T10:39:39.000Z",
297+
"updatedAt":"2020-09-29T10:39:39.000Z",
298+
"attributes":[]
299+
}
300+
]
301+
}
302+
]
303+
"""
304+
)
305+
with self.assertRaisesRegexp(
306+
AppException,
307+
f"OCR attribute group is not supported for project type {self.PROJECT_TYPE}.",
308+
):
309+
sa.create_annotation_classes_from_classes_json(
310+
self.PROJECT_NAME, temp_path
311+
)
312+
248313

249314
class TestCreateAnnotationClassPixel(BaseTestCase):
250315
PROJECT_NAME = "TestCreateAnnotationClassPixel"

tests/unit/test_async_functions.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77
sa = SAClient()
88

99

10+
class DummyIterator:
11+
def __init__(self, delay, to):
12+
self.delay = delay
13+
self.i = 0
14+
self.to = to
15+
16+
def __aiter__(self):
17+
return self
18+
19+
async def __anext__(self):
20+
i = self.i
21+
if i >= self.to:
22+
raise StopAsyncIteration
23+
self.i += 1
24+
if i:
25+
await asyncio.sleep(self.delay)
26+
return i
27+
28+
1029
class TestAsyncFunctions(TestCase):
1130
PROJECT_NAME = "TestAsync"
1231
PROJECT_DESCRIPTION = "Desc"
@@ -26,37 +45,37 @@ def setUpClass(cls):
2645
def tearDownClass(cls):
2746
sa.delete_project(cls.PROJECT_NAME)
2847

48+
@staticmethod
49+
async def nested():
50+
annotations = sa.get_annotations(TestAsyncFunctions.PROJECT_NAME)
51+
assert len(annotations) == 4
52+
2953
def test_get_annotations_in_running_event_loop(self):
3054
async def _test():
3155
annotations = sa.get_annotations(self.PROJECT_NAME)
3256
assert len(annotations) == 4
3357

3458
asyncio.run(_test())
3559

36-
def test_multiple_get_annotations_in_running_event_loop(self):
37-
# TODO add handling of nested loop
38-
async def nested():
39-
sa.attach_items(self.PROJECT_NAME, self.ATTACH_PAYLOAD)
40-
annotations = sa.get_annotations(self.PROJECT_NAME)
41-
assert len(annotations) == 4
42-
43-
async def create_task_test():
44-
import nest_asyncio
60+
def test_create_task_get_annotations_in_running_event_loop(self):
61+
async def _test():
62+
task1 = asyncio.create_task(self.nested())
63+
task2 = asyncio.create_task(self.nested())
64+
await task1
65+
await task2
4566

46-
nest_asyncio.apply()
47-
task1 = asyncio.create_task(nested())
48-
task2 = asyncio.create_task(nested())
49-
await task1
50-
await task2
51-
52-
asyncio.run(create_task_test())
67+
asyncio.run(_test())
5368

69+
def test_gather_get_annotations_in_running_event_loop(self):
5470
async def gather_test():
55-
import nest_asyncio
56-
57-
nest_asyncio.apply()
58-
await asyncio.gather(nested(), nested())
71+
await asyncio.gather(self.nested(), self.nested())
72+
asyncio.run(gather_test())
5973

74+
def test_gather_async_for(self):
75+
async def gather_test():
76+
async for _ in DummyIterator(delay=0.01, to=2):
77+
annotations = sa.get_annotations(TestAsyncFunctions.PROJECT_NAME)
78+
assert len(annotations) == 4
6079
asyncio.run(gather_test())
6180

6281
def test_upload_annotations_in_running_event_loop(self):

0 commit comments

Comments
 (0)