Skip to content

Commit f8f8181

Browse files
authored
Merge pull request #596 from superannotateai/1865_run_prediction
Fixed run_prediction
2 parents ce9d6a5 + 77b7e4c commit f8f8181

File tree

4 files changed

+36
-49
lines changed

4 files changed

+36
-49
lines changed

src/superannotate/lib/core/usecases/models.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def execute(self):
479479
images = self._service_provider.items.list_by_names(
480480
project=self._project, folder=self._folder, names=self._images_list
481481
).data
482-
image_ids = [image.uuid for image in images]
482+
image_ids = [image.id for image in images]
483483
image_names = [image.name for image in images]
484484

485485
if not len(image_names):
@@ -502,36 +502,36 @@ def execute(self):
502502
ml_model_id=ml_model.id,
503503
image_ids=image_ids,
504504
)
505-
if not res.ok:
506-
return self._response.data
507-
508-
success_images = []
509-
failed_images = []
510-
while len(success_images) + len(failed_images) != len(image_ids):
511-
images_metadata = self._service_provider.items.list_by_names(
512-
project=self._project, folder=self._folder, names=self._images_list
513-
).data
514-
515-
success_images = [
516-
img.name
517-
for img in images_metadata
518-
if img.prediction_status
519-
== constances.SegmentationStatus.COMPLETED.value
520-
]
521-
failed_images = [
522-
img.name
523-
for img in images_metadata
524-
if img.prediction_status
525-
== constances.SegmentationStatus.FAILED.value
526-
]
527-
528-
complete_images = success_images + failed_images
529-
logger.info(
530-
f"prediction complete on {len(complete_images)} / {len(image_ids)} images"
531-
)
532-
time.sleep(5)
505+
if res.ok:
506+
success_images = []
507+
failed_images = []
508+
while len(success_images) + len(failed_images) != len(image_ids):
509+
images_metadata = self._service_provider.items.list_by_names(
510+
project=self._project, folder=self._folder, names=self._images_list
511+
).data
512+
513+
success_images = [
514+
img.name
515+
for img in images_metadata
516+
if img.prediction_status
517+
== constances.SegmentationStatus.COMPLETED.value
518+
]
519+
failed_images = [
520+
img.name
521+
for img in images_metadata
522+
if img.prediction_status
523+
== constances.SegmentationStatus.FAILED.value
524+
]
525+
526+
complete_images = success_images + failed_images
527+
logger.info(
528+
f"prediction complete on {len(complete_images)} / {len(image_ids)} images"
529+
)
530+
time.sleep(5)
533531

534-
self._response.data = (success_images, failed_images)
532+
self._response.data = (success_images, failed_images)
533+
else:
534+
self._response.errors = res.error
535535
return self._response
536536

537537

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
import io
2-
from typing import List
3-
4-
from lib.core.conditions import Condition
5-
from lib.core.entities import ProjectEntity
62
from lib.core.entities import S3FileEntity
73
from lib.core.repositories import BaseS3Repository
84

@@ -22,12 +18,3 @@ def insert(self, entity: S3FileEntity) -> S3FileEntity:
2218
data["Metadata"] = temp
2319
self.bucket.put_object(**data)
2420
return entity
25-
26-
def update(self, entity: ProjectEntity):
27-
self._service.update_project(entity.to_dict())
28-
29-
def delete(self, uuid: int):
30-
self._service.delete_project(uuid)
31-
32-
def get_all(self, condition: Condition = None) -> List[ProjectEntity]:
33-
pass

src/superannotate/lib/infrastructure/serviceprovider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def run_prediction(
195195
self.URL_PREDICTION,
196196
"post",
197197
data={
198+
"team_id": project.team_id,
198199
"project_id": project.id,
199200
"ml_model_id": ml_model_id,
200201
"image_ids": image_ids,

tests/integration/test_ml_funcs.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
2-
import time
32
from os.path import dirname
43

5-
import pytest
64
from src.superannotate import SAClient
5+
from src.superannotate import AppException
76
from tests.integration.base import BaseTestCase
7+
import pytest
88

99
sa = SAClient()
1010

@@ -23,17 +23,16 @@ def folder_path(self):
2323
return os.path.join(dirname(dirname(__file__)), self.TEST_FOLDER_PATH)
2424

2525
def test_run_prediction_with_non_exist_images(self):
26-
with pytest.raises(Exception) as e:
26+
with self.assertRaisesRegexp(AppException, 'No valid image names were provided.'):
2727
sa.run_prediction(
28-
self.PROJECT_NAME, ["NonExistantImage.jpg"], self.MODEL_NAME
28+
self.PROJECT_NAME, ["NotExistingImage.jpg"], self.MODEL_NAME
2929
)
3030

31-
@pytest.mark.skip(reason="Need to adjust")
31+
@pytest.mark.skip(reason="Test skipped due to long execution")
3232
def test_run_prediction_for_all_images(self):
3333
sa.upload_images_from_folder_to_project(
3434
project=self.PROJECT_NAME, folder_path=self.folder_path
3535
)
36-
time.sleep(2)
3736
image_names_vector = [i["name"] for i in sa.search_items(self.PROJECT_NAME)]
3837
succeeded_images, failed_images = sa.run_prediction(
3938
self.PROJECT_NAME, image_names_vector, self.MODEL_NAME

0 commit comments

Comments
 (0)