Skip to content

Commit 2e1a150

Browse files
committed
Fix neural networks docs
1 parent c300eb5 commit 2e1a150

File tree

2 files changed

+46
-22
lines changed

2 files changed

+46
-22
lines changed

superannotate/ml/ml_funcs.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,15 @@
3333
@model_metadata
3434
def run_prediction(project, images_list, model):
3535
"""This function runs smart prediction on given list of images from a given project using the neural network of your choice
36+
3637
:param project: the project in which the target images are uploaded.
3738
:type project: str or dict
3839
:param images_list: the list of image names on which smart prediction has to be run
3940
:type images_list: list of str
4041
:param model: the name of the model that should be used for running smart prediction
4142
:type model: str or dict
42-
:out res: tupe of two lists, list of images on which the prediction has succeded and failed respectively
43-
:rtype res: tuple
43+
:return: tupe of two lists, list of images on which the prediction has succeded and failed respectively
44+
:rtype: tuple
4445
"""
4546

4647
if not isinstance(project, dict):
@@ -114,7 +115,7 @@ def run_segmentation(project, images_list, model):
114115
:type project: str or dict
115116
:param model : The model name or metadata of the model
116117
:type model : str or dict
117-
:out res: tupe of two lists, list of images on which the prediction has succeded and failed respectively
118+
:return: tupe of two lists, list of images on which the prediction has succeded and failed respectively
118119
:rtype res: tuple
119120
"""
120121

@@ -194,22 +195,23 @@ def run_training(
194195
log=False
195196
):
196197
"""Runs neural network training
198+
197199
:param project: project or list of projects that contain the training images
198200
:type project: str, dict or list of dict
199-
:param base_model : base model on which the new network will be trained
200-
:type base_model : str or dict
201-
:param model_name : name of the new model
202-
:type model_name : str
201+
:param base_model: base model on which the new network will be trained
202+
:type base_model: str or dict
203+
:param model_name: name of the new model
204+
:type model_name: str
203205
:param model_description: description of the new model
204206
:type model_description: str
205-
:param task : The model training task
206-
:type task : str
207+
:param task: The model training task
208+
:type task: str
207209
:param hyperparameters: hyperparameters that should be used in training
208210
:type hyperparameters: dict
209211
:param log: If true will log training metrics in the stdout
210212
:type log: boolean
211-
:out new_model: the metadata of the newly created model
212-
:rtype new_model: dict
213+
:return: the metadata of the newly created model
214+
:rtype: dict
213215
"""
214216

215217
project_ids = None
@@ -350,12 +352,13 @@ def run_training(
350352
@Trackable
351353
@model_metadata
352354
def stop_model_training(model):
353-
'''This function will stop training model provided by either name or metadata, and return the ID
355+
"""This function will stop training model provided by either name or metadata, and return the ID
356+
354357
:param model: The name or the metadata of the model the training of which the user needs to terminate
355-
:type model: str or dict
356-
:out model: the metadata of the now, stopped model
357-
:rtype out: dict
358-
'''
358+
:type model: dict
359+
:return: the metadata of the now, stopped model
360+
:rtype: dict
361+
"""
359362
params = {"team_id": _api.team_id}
360363
response = _api.send_request(
361364
req_type="POST",
@@ -373,6 +376,7 @@ def stop_model_training(model):
373376
@Trackable
374377
def plot_model_metrics(metric_json_list):
375378
"""plots the metrics generated by neural network using plotly
379+
376380
:param metric_json_list: list of <model_name>.json files
377381
:type metric_json_list: list of str
378382
"""
@@ -445,13 +449,15 @@ def get_plottable_cols(df):
445449
def download_model(model, output_dir):
446450
"""Downloads the neural network and related files
447451
which are the <model_name>.pth/pkl. <model_name>.json, <model_name>.yaml, classes_mapper.json
452+
448453
:param model: the model that needs to be downloaded
449454
:type model: str
450455
:param output_dir: the directiory in which the files will be saved
451456
:type output_dir: str
452-
:out model: the metadata of the model that was deleted
453-
:rtype model: dict
457+
:return: the metadata of the model
458+
:rtype: dict
454459
"""
460+
455461
if not os.path.isdir(output_dir):
456462
os.makedirs(output_dir, exist_ok=True)
457463

@@ -506,11 +512,13 @@ def download_model(model, output_dir):
506512
@Trackable
507513
@model_metadata
508514
def delete_model(model):
509-
'''This function deletes the provided model
515+
"""This function deletes the provided model
516+
510517
:param model: the model to be deleted
511-
:out model: the metadata of the model that was deleted
512-
:rtype out: dict
513-
'''
518+
:type model: dict
519+
:return: the metadata of the model that was deleted
520+
:rtype: dict
521+
"""
514522
params = {"team_id": _api.team_id}
515523
response = _api.send_request(
516524
req_type="DELETE", path=f'/ml_model/{model["id"]}', params=params

superannotate/ml/ml_models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@ def search_models(
1616
task=None,
1717
include_global=True,
1818
):
19+
"""Search for ML models.
20+
21+
:param name: search string
22+
:type name: str
23+
:param type_: ml model type string
24+
:type type_: str
25+
:param project_id: project id
26+
:type project_id: str
27+
:param task: training task
28+
:type task: str
29+
:param include_global: include global ml models
30+
:type include_global: bool
31+
32+
:return: ml model metadata
33+
:rtype: list of dicts
34+
"""
1935
params = {
2036
"name": name,
2137
"team_id": _api.team_id,

0 commit comments

Comments
 (0)