diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index be97a4c9..b092747c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -166,6 +166,7 @@ jobs: env: LUXONISML_BUCKET: luxonis-test-bucket SUITE: ${{ matrix.suite }} + HUBAI_API_KEY: ${{ secrets.HUBAI_API_KEY }} run: pytest -x --cov --junitxml=junit.xml -o junit_family=legacy -m "${SUITE}" - name: Upload test results to Codecov diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 9fecd399..792433fc 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -29,6 +29,9 @@ on: CODECOV_TOKEN: description: 'Codecov upload token' required: true + HUBAI_API_KEY: + description: 'HubAI API key' + required: true permissions: pull-requests: write @@ -92,6 +95,7 @@ jobs: working-directory: luxonis-train env: LUXONISML_BUCKET: luxonis-test-bucket + HUBAI_API_KEY: ${{ secrets.HUBAI_API_KEY }} run: pytest --cov --junitxml=junit.xml -o junit_family=legacy - name: Upload test results to Codecov diff --git a/configs/README.md b/configs/README.md index af7a3aa3..860329fd 100644 --- a/configs/README.md +++ b/configs/README.md @@ -28,7 +28,8 @@ You can create your own config or use/edit one of the examples. - [Trainer Tips](#trainer-tips) - [Exporter](#exporter) - [`ONNX`](#onnx) - - [Blob](#blob) + - [HubAI](#hubai) + - [Blob (Deprecated)](#blob-deprecated) - [Tuner](#tuner) - [Storage](#storage) - [ENVIRON](#environ) @@ -355,10 +356,13 @@ trainer: patience: 3 monitor: "val/loss" mode: "min" - - name: "ExportOnTrainEnd" + - name: "ConvertOnTrainEnd" - name: "TestOnTrainEnd" ``` +> [!NOTE] +> `ConvertOnTrainEnd` is the recommended callback for model conversion. It combines export, archive, and platform-specific conversion (blobconverter/HubAI SDK) into a single step. Use this instead of separate `ExportOnTrainEnd` and `ArchiveOnTrainEnd` callbacks. + ### Optimizer What optimizer to use for training. @@ -490,14 +494,15 @@ Here you can define configuration for exporting. | ------------------------ | --------------------------------- | ------------- | ---------------------------------------------------------------------------------------------- | | `name` | `str \| None` | `None` | Name of the exported model | | `input_shape` | `list\[int\] \| None` | `None` | Input shape of the model. If not provided, inferred from the dataset | -| `data_type` | `Literal["INT8", "FP16", "FP32"]` | `"FP16"` | Data type of the exported model. Only used for conversion to BLOB | +| `target_precision` | `Literal["INT8", "FP16", "FP32"]` | `"FP16"` | Data type of the exported model. Alias: `data_type` | | `reverse_input_channels` | `bool` | `True` | Whether to reverse the image channels in the exported model. Relevant for `BLOB` export | | `scale_values` | `list[float] \| None` | `None` | What scale values to use for input normalization. If not provided, inferred from augmentations | | `mean_values` | `list[float] \| None` | `None` | What mean values to use for input normalization. If not provided, inferred from augmentations | | `upload_to_run` | `bool` | `True` | Whether to upload the exported files to tracked run as artifact | | `upload_url` | `str \| None` | `None` | Exported model will be uploaded to this URL if specified | | `onnx` | `dict` | `{}` | Options specific for ONNX export. See [ONNX](#onnx) section for details | -| `blobconverter` | `dict` | `{}` | Options for converting to BLOB format. See [Blob](#blob) section for details | +| `hubai` | `dict` | `{}` | Options for HubAI SDK conversion. See [HubAI](#hubai) section for details | +| `blobconverter` | `dict` | `{}` | Options for converting to BLOB format (deprecated). See [Blob](#blob-deprecated) section | ### `ONNX` @@ -510,7 +515,37 @@ Option specific for `ONNX` export. | `disable_onnx_simplification` | `bool` | `False` | Disable ONNX simplification after export | | `unique_onnx_initializers` | `bool` | `False` | Re-assign names to identifiers after export to ensure they are per-block unique | -### `Blob` +### `HubAI` + +The [HubAI SDK](https://github.com/luxonis/hubai-sdk) provides model conversion for multiple platforms (RVC2, RVC3, RVC4, Hailo). +This is the recommended way to convert models for deployment. + +> [!NOTE] +> Requires `HUBAI_API_KEY` environment variable to be set. + +| Key | Type | Default value | Description | +| --------------------- | ------------------------------------------ | ------------- | ----------------------------------------------------------------- | +| `active` | `bool` | `False` | Whether to use HubAI SDK for conversion | +| `platform` | `Literal["rvc2", "rvc3", "rvc4", "hailo"]` | `None` | Target platform for conversion. Required when `active` is `True` | +| `delete_remote_model` | `bool` | `False` | Clean up by deleting remote uploaded variant in HubAI | +| `params` | `dict` | `{}` | Additional parameters passed to the HubAI SDK conversion function | + +**Example:** + +```yaml +exporter: + target_precision: fp16 + hubai: + active: true + platform: rvc2 + params: + superblob: True +``` + +### `Blob` (Deprecated) + +> [!WARNING] +> `blobconverter` is deprecated and only supports RVC2 legacy conversion to `.blob`. | Key | Type | Default value | Description | | --------- | ---------------------------------------------------------------- | ------------- | ---------------------------------------- | @@ -554,6 +589,7 @@ Here you can specify options for tuning. > - `UploadCheckpoint` > - `ExportOnTrainEnd` > - `ArchiveOnTrainEnd` +> - `ConvertOnTrainEnd` > - `TestOnTrainEnd` ### Storage @@ -609,6 +645,7 @@ For more info on the variables, see [Credentials](../README.md#credentials). | `AWS_ACCESS_KEY_ID` | `str \| None` | `None` | | `AWS_SECRET_ACCESS_KEY` | `str \| None` | `None` | | `AWS_S3_ENDPOINT_URL` | `str \| None` | `None` | +| `HUBAI_API_KEY` | `str \| None` | `None` | | `MLFLOW_CLOUDFLARE_ID` | `str \| None` | `None` | | `MLFLOW_CLOUDFLARE_SECRET` | `str \| None` | `None` | | `MLFLOW_S3_BUCKET` | `str \| None` | `None` | diff --git a/configs/complex_model.yaml b/configs/complex_model.yaml index 6e513b76..d6a2f9d7 100644 --- a/configs/complex_model.yaml +++ b/configs/complex_model.yaml @@ -116,8 +116,7 @@ trainer: patience: 3 monitor: val/loss mode: min - - name: ExportOnTrainEnd - - name: ArchiveOnTrainEnd + - name: ConvertOnTrainEnd - name: TestOnTrainEnd optimizer: diff --git a/configs/detection_heavy_model.yaml b/configs/detection_heavy_model.yaml index 21eb3197..c8dd519a 100644 --- a/configs/detection_heavy_model.yaml +++ b/configs/detection_heavy_model.yaml @@ -46,8 +46,7 @@ trainer: decay: 0.9999 use_dynamic_decay: True decay_tau: 2000 - - name: ExportOnTrainEnd - - name: TestOnTrainEnd + - name: ConvertOnTrainEnd training_strategy: name: TripleLRSGDStrategy diff --git a/configs/detection_light_model.yaml b/configs/detection_light_model.yaml index 5b9def44..3d0fbd07 100644 --- a/configs/detection_light_model.yaml +++ b/configs/detection_light_model.yaml @@ -46,8 +46,7 @@ trainer: decay: 0.9999 use_dynamic_decay: True decay_tau: 2000 - - name: ExportOnTrainEnd - - name: TestOnTrainEnd + - name: ConvertOnTrainEnd training_strategy: name: "TripleLRSGDStrategy" diff --git a/configs/fomo_heavy_model.yaml b/configs/fomo_heavy_model.yaml index c2c63ea5..a6de2453 100644 --- a/configs/fomo_heavy_model.yaml +++ b/configs/fomo_heavy_model.yaml @@ -27,5 +27,4 @@ trainer: gradient_clip_val: 10 callbacks: - - name: ExportOnTrainEnd - - name: TestOnTrainEnd + - name: ConvertOnTrainEnd diff --git a/configs/fomo_light_model.yaml b/configs/fomo_light_model.yaml index 422602f9..9c07b1c0 100644 --- a/configs/fomo_light_model.yaml +++ b/configs/fomo_light_model.yaml @@ -27,5 +27,4 @@ trainer: gradient_clip_val: 10 callbacks: - - name: ExportOnTrainEnd - - name: TestOnTrainEnd + - name: ConvertOnTrainEnd diff --git a/configs/instance_segmentation_heavy_model.yaml b/configs/instance_segmentation_heavy_model.yaml index 50a7ef11..bd89c2b0 100644 --- a/configs/instance_segmentation_heavy_model.yaml +++ b/configs/instance_segmentation_heavy_model.yaml @@ -42,8 +42,7 @@ trainer: decay: 0.9999 use_dynamic_decay: True decay_tau: 2000 - - name: ExportOnTrainEnd - - name: TestOnTrainEnd + - name: ConvertOnTrainEnd - name: GradientAccumulationScheduler params: # warmup phase is 3 epochs diff --git a/configs/instance_segmentation_light_model.yaml b/configs/instance_segmentation_light_model.yaml index d4c8bd1d..b247e46c 100644 --- a/configs/instance_segmentation_light_model.yaml +++ b/configs/instance_segmentation_light_model.yaml @@ -42,8 +42,7 @@ trainer: decay: 0.9999 use_dynamic_decay: True decay_tau: 2000 - - name: ExportOnTrainEnd - - name: TestOnTrainEnd + - name: ConvertOnTrainEnd - name: GradientAccumulationScheduler params: scheduling: # warmup phase is 3 epochs diff --git a/configs/keypoint_bbox_heavy_model.yaml b/configs/keypoint_bbox_heavy_model.yaml index 18967d12..0c95683b 100644 --- a/configs/keypoint_bbox_heavy_model.yaml +++ b/configs/keypoint_bbox_heavy_model.yaml @@ -49,8 +49,7 @@ trainer: decay: 0.9999 use_dynamic_decay: True decay_tau: 2000 - - name: ExportOnTrainEnd - - name: TestOnTrainEnd + - name: ConvertOnTrainEnd # For best results, always accumulate gradients to # effectively use 64 batch size - name: GradientAccumulationScheduler diff --git a/configs/keypoint_bbox_light_model.yaml b/configs/keypoint_bbox_light_model.yaml index 0396a719..6340eb61 100644 --- a/configs/keypoint_bbox_light_model.yaml +++ b/configs/keypoint_bbox_light_model.yaml @@ -49,8 +49,7 @@ trainer: decay: 0.9999 use_dynamic_decay: True decay_tau: 2000 - - name: ExportOnTrainEnd - - name: TestOnTrainEnd + - name: ConvertOnTrainEnd # For best results, always accumulate gradients to # effectively use 64 batch size - name: GradientAccumulationScheduler diff --git a/configs/ocr_recognition_light_model.yaml b/configs/ocr_recognition_light_model.yaml index 81cc108d..45a90155 100644 --- a/configs/ocr_recognition_light_model.yaml +++ b/configs/ocr_recognition_light_model.yaml @@ -28,8 +28,7 @@ trainer: n_log_images: 8 callbacks: - - name: TestOnTrainEnd - - name: ExportOnTrainEnd + - name: ConvertOnTrainEnd optimizer: name: Adam diff --git a/configs/segmentation_heavy_model.yaml b/configs/segmentation_heavy_model.yaml index 8d3cf2c2..6d1186f2 100644 --- a/configs/segmentation_heavy_model.yaml +++ b/configs/segmentation_heavy_model.yaml @@ -24,8 +24,7 @@ trainer: n_log_images: 8 callbacks: - - name: TestOnTrainEnd - - name: ExportOnTrainEnd + - name: ConvertOnTrainEnd optimizer: name: SGD diff --git a/configs/segmentation_light_model.yaml b/configs/segmentation_light_model.yaml index c2e7a603..decb838e 100644 --- a/configs/segmentation_light_model.yaml +++ b/configs/segmentation_light_model.yaml @@ -25,8 +25,7 @@ trainer: n_log_images: 8 callbacks: - - name: TestOnTrainEnd - - name: ExportOnTrainEnd + - name: ConvertOnTrainEnd optimizer: name: SGD diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index 032531f3..01670893 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -401,6 +401,36 @@ def archive( ) +@app.command(group=export_group, sort_key=3) +def convert( + opts: list[str] | None = None, + /, + *, + config: str | None = None, + save_dir: str | None = None, + weights: str | None = None, +): + """Export, archive, and convert the model to target platform format. + + This is a unified command that combines export, archive, and + platform conversion (RVC2/RVC3/RVC4/Hailo) steps based on the + configuration. + + @type config: str + @param config: Path to the configuration file. + @type save_dir: str + @param save_dir: Directory where all outputs will be saved. If not + specified, the default run save directory will be used. + @type weights: str + @param weights: Path to the model weights. + @type opts: list[str] + @param opts: A list of optional CLI overrides of the config file. + """ + create_model(config, opts, weights=weights).convert( + weights=weights, save_dir=save_dir + ) + + @upgrade_app.command() def config( config: Annotated[ diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md index dafcdd03..d6622db8 100644 --- a/luxonis_train/callbacks/README.md +++ b/luxonis_train/callbacks/README.md @@ -9,6 +9,7 @@ List of all supported callbacks. - [`PytorchLightning` Callbacks](#pytorchlightning-callbacks) - [`ExportOnTrainEnd`](#exportontrainend) - [`ArchiveOnTrainEnd`](#archiveontrainend) +- [`ConvertOnTrainEnd`](#convertontrainend) - [`MetadataLogger`](#metadatalogger) - [`TestOnTrainEnd`](#testontrainend) - [`UploadCheckpoint`](#uploadcheckpoint) @@ -51,6 +52,25 @@ Callback to create an `NN Archive` at the end of the training. | ---------------------- | --------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `preferred_checkpoint` | `Literal["metric", "loss"]` | `"metric"` | Which checkpoint should the callback use. If the preferred checkpoint is not available, the other option is used. If none is available, the callback is skipped | +## `ConvertOnTrainEnd` + +Unified callback that exports, archives, and converts the archive to the target platform at the end of training. This is the recommended callback for model conversion as it combines the functionality of `ExportOnTrainEnd` and `ArchiveOnTrainEnd`, and also runs platform-specific conversions (blobconverter or HubAI SDK) if configured. + +**Steps:** + +
    +
  1. Exports the model to ONNX
  2. +
  3. Creates an NN Archive from the ONNX
  4. +
  5. Runs blobconverter if `exporter.blobconverter.active` is `true`
  6. +
  7. Runs HubAI SDK conversion if `exporter.hubai.active` is `true`
  8. +
+ +**Parameters:** + +| Key | Type | Default value | Description | +| ---------------------- | --------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `preferred_checkpoint` | `Literal["metric", "loss"]` | `"metric"` | Which checkpoint should the callback use. If the preferred checkpoint is not available, the other option is used. If none is available, the callback is skipped | + ## `MetadataLogger` Callback that logs training metadata. diff --git a/luxonis_train/callbacks/__init__.py b/luxonis_train/callbacks/__init__.py index 2e8e607d..5aa5135d 100644 --- a/luxonis_train/callbacks/__init__.py +++ b/luxonis_train/callbacks/__init__.py @@ -12,6 +12,7 @@ from luxonis_train.registry import CALLBACKS from .archive_on_train_end import ArchiveOnTrainEnd +from .convert_on_train_end import ConvertOnTrainEnd from .ema import EMACallback from .export_on_train_end import ExportOnTrainEnd from .gpu_stats_monitor import GPUStatsMonitor @@ -43,11 +44,13 @@ CALLBACKS.register(module=TrainingManager) CALLBACKS.register(module=GracefulInterruptCallback) CALLBACKS.register(module=TrainingProgressCallback) +CALLBACKS.register(module=ConvertOnTrainEnd) __all__ = [ "ArchiveOnTrainEnd", "BaseLuxonisProgressBar", + "ConvertOnTrainEnd", "EMACallback", "ExportOnTrainEnd", "GPUStatsMonitor", diff --git a/luxonis_train/callbacks/convert_on_train_end.py b/luxonis_train/callbacks/convert_on_train_end.py new file mode 100644 index 00000000..89ba476c --- /dev/null +++ b/luxonis_train/callbacks/convert_on_train_end.py @@ -0,0 +1,28 @@ +import lightning.pytorch as pl +from loguru import logger + +import luxonis_train as lxt + +from .needs_checkpoint import NeedsCheckpoint + + +class ConvertOnTrainEnd(NeedsCheckpoint): + """Callback that exports, archives, and converts the model on train + end.""" + + def on_train_end( + self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule" + ) -> None: + """Converts the model on train end. + + @type trainer: L{pl.Trainer} + @param trainer: Pytorch Lightning trainer. + @type pl_module: L{pl.LightningModule} + @param pl_module: Pytorch Lightning module. + """ + checkpoint = self.get_checkpoint(pl_module) + if checkpoint is None: # pragma: no cover + logger.warning("Skipping model conversion.") + return + + pl_module.core.convert(weights=checkpoint) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 3e0fcd83..6b6a2863 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -559,6 +559,35 @@ def reorder_callbacks(self) -> Self: self.callbacks.sort(key=lambda v: 0 if v.name == "EMACallback" else 1) return self + @model_validator(mode="after") + def check_convert_callbacks(self) -> Self: + callback_names = {cb.name for cb in self.callbacks if cb.active} + has_convert = "ConvertOnTrainEnd" in callback_names + has_export = "ExportOnTrainEnd" in callback_names + has_archive = "ArchiveOnTrainEnd" in callback_names + + if has_convert and (has_export or has_archive): + redundant = [] + for cb in self.callbacks: + if ( + cb.name in ("ExportOnTrainEnd", "ArchiveOnTrainEnd") + and cb.active + ): + cb.active = False + redundant.append(cb.name) + if redundant: + logger.warning( + f"Deactivated {redundant} because 'ConvertOnTrainEnd' is active " + "and already includes export and archive functionality." + ) + elif has_export and has_archive: + logger.warning( + "Both 'ExportOnTrainEnd' and 'ArchiveOnTrainEnd' callbacks are set. " + "Consider using 'ConvertOnTrainEnd' instead, which combines both " + "and also handles platform-specific conversions (blobconverter/HubAI SDK)." + ) + return self + class OnnxExportConfig(BaseModelExtraForbid): opset_version: PositiveInt = 16 @@ -575,6 +604,22 @@ class BlobconverterExportConfig(BaseModelExtraForbid): ) +class HubAIExportConfig(BaseModelExtraForbid): + active: bool = False + platform: Literal["rvc2", "rvc3", "rvc4", "hailo"] | None = None + params: Params = Field(default_factory=dict) + delete_remote_model: bool = False + + @model_validator(mode="after") + def validate_platform(self) -> Self: + if self.active and self.platform is None: + raise ValueError( + "The `platform` field is required when `hubai.active` is True. " + "Please specify a target platform: 'rvc2', 'rvc3', 'rvc4', or 'hailo'." + ) + return self + + class ArchiveConfig(BaseModelExtraForbid): name: str | None = None upload_to_run: bool = True @@ -584,7 +629,10 @@ class ArchiveConfig(BaseModelExtraForbid): class ExportConfig(ArchiveConfig): name: str | None = None input_shape: list[int] | None = None - data_type: Literal["int8", "fp16", "fp32"] = "fp16" + target_precision: Annotated[ + Literal["int8", "fp16", "fp32"], + Field(validation_alias=AliasChoices("target_precision", "data_type")), + ] = "fp16" reverse_input_channels: bool | None = None scale_values: list[float] | None = None mean_values: list[float] | None = None @@ -592,6 +640,7 @@ class ExportConfig(ArchiveConfig): blobconverter: BlobconverterExportConfig = Field( default_factory=BlobconverterExportConfig ) + hubai: HubAIExportConfig = Field(default_factory=HubAIExportConfig) @field_validator("scale_values", "mean_values", mode="before") @classmethod @@ -862,8 +911,7 @@ def smart_auto_populate(self) -> Self: default_callbacks = [ "UploadCheckpoint", "TestOnTrainEnd", - "ExportOnTrainEnd", - "ArchiveOnTrainEnd", + "ConvertOnTrainEnd", ] for cb_name in default_callbacks: diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 8a62a4d6..1b8011c0 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -50,6 +50,7 @@ from .utils.export_utils import ( blobconverter_export, get_preprocessing, + hubai_export, make_initializers_unique, replace_weights, try_onnx_simplify, @@ -407,10 +408,6 @@ def export( This is useful for updating the metadata in the checkpoint file in case they changed (e.g. new configuration file, architectural changes affecting the exection order etc.) - @type unique_onnx_initializers: bool - @param unique_onnx_initializers: If True, a single pass through the - onnx model is done after export to ensure that identifiers are unique. - @raises RuntimeError: If C{onnxsim} fails to simplify the model. """ weights = weights or self.cfg.model.weights @@ -471,34 +468,6 @@ def export( ) scale_values = self.cfg.exporter.scale_values or scale mean_values = self.cfg.exporter.mean_values or mean - if self.cfg.exporter.reverse_input_channels is not None: - reverse_input_channels = self.cfg.exporter.reverse_input_channels - else: - logger.info( - "`exporter.reverse_input_channels` not specified. " - "Using the `trainer.preprocessing.color_space` value " - "to determine if the channels should be reversed. " - f"`color_space` = '{color_space}' -> " - f"`reverse_input_channels` = `{color_space == 'RGB'}`" - ) - reverse_input_channels = color_space == "RGB" - - if self.cfg.exporter.blobconverter.active: - try: - self._exported_models["blob"] = blobconverter_export( - self.cfg.exporter, - scale_values, - mean_values, - reverse_input_channels, - str(export_save_dir), - onnx_save_path, - ) - except ImportError: - logger.error("Failed to import `blobconverter`") - logger.warning( - "`blobconverter` not installed. Skipping .blob model conversion. " - "Ensure `blobconverter` is installed in your environment." - ) for path in self._exported_models.values(): if self.cfg.exporter.upload_to_run: @@ -1046,6 +1015,124 @@ def _archive( return Path(archive_path) + def convert( + self, + weights: PathType | None = None, + save_dir: PathType | None = None, + ) -> Path: + """Exports the model to ONNX, creates an NN Archive, and + converts to target platform format (RVC2/RVC3/RVC4/Hailo). + + This is a unified method that combines export, archive, and platform + conversion steps. + + @type weights: PathType | None + @param weights: Path to the checkpoint from which to load weights. + If not specified, the value of `model.weights` from the + configuration file will be used. + @type save_dir: PathType | None + @param save_dir: Directory where the outputs will be saved. + If not specified, the default run save directory will be used. + @rtype: Path + @return: Path to the generated NN Archive. + """ + self.export(weights=weights, save_path=save_dir) + + onnx_path = self._exported_models.get("onnx") + if onnx_path is None: + raise RuntimeError( + "ONNX export failed, cannot proceed with conversion." + ) + + archive_path = self.archive( + path=onnx_path, weights=weights, save_dir=save_dir + ) + + mean, scale, color_space = get_preprocessing( + self.cfg_preprocessing, "Model conversion" + ) + scale_values = self.cfg.exporter.scale_values or scale + mean_values = self.cfg.exporter.mean_values or mean + if self.cfg.exporter.reverse_input_channels is not None: + reverse_input_channels = self.cfg.exporter.reverse_input_channels + else: + logger.info( + "`exporter.reverse_input_channels` not specified. " + "Using the `trainer.preprocessing.color_space` value " + "to determine if the channels should be reversed. " + f"`color_space` = '{color_space}' -> " + f"`reverse_input_channels` = `{color_space == 'RGB'}`" + ) + reverse_input_channels = color_space == "RGB" + + convert_save_dir = ( + Path(save_dir) if save_dir else Path(self.run_save_dir) + ) + + if self.cfg.exporter.blobconverter.active: + logger.warning( + "blobconverter is deprecated and only supports RVC2 legacy conversion to `.blob`." + "Please consider using the HubAI SDK instead." + ) + try: + blob_path = blobconverter_export( + self.cfg.exporter, + scale_values, + mean_values, + reverse_input_channels, + str(convert_save_dir), + str(onnx_path), + ) + self._exported_models["blob"] = blob_path + if self.cfg.exporter.upload_to_run: + self.tracker.upload_artifact(blob_path, typ="export") + if self.cfg.exporter.upload_url is not None: + LuxonisFileSystem.upload( + blob_path, self.cfg.exporter.upload_url + ) + except ImportError: + logger.error("Failed to import `blobconverter`") + logger.warning( + "`blobconverter` not installed. Skipping .blob model conversion. " + "Ensure `blobconverter` is installed in your environment." + ) + + if self.cfg.exporter.hubai.active: + try: + dataset_name = None + if "train" in self.loaders and hasattr( + self.loaders["train"], "dataset" + ): + dataset = getattr(self.loaders["train"], "dataset", None) + if dataset is not None: + dataset_name = getattr(dataset, "dataset_name", None) + hubai_archive_path = hubai_export( + cfg=self.cfg.exporter.hubai, + target_precision=self.cfg.exporter.target_precision, + archive_path=archive_path, + export_path=convert_save_dir, + model_name=self.cfg.model.name, + dataset_name=dataset_name, + ) + if self.cfg.archiver.upload_to_run: + self.tracker.upload_artifact( + hubai_archive_path, typ="archive" + ) + if self.cfg.archiver.upload_url is not None: + LuxonisFileSystem.upload( + hubai_archive_path, self.cfg.archiver.upload_url + ) + except ImportError: + logger.error("Failed to import `hubai_sdk`") + logger.warning( + "`hubai-sdk` not installed. Skipping HubAI conversion. " + "Ensure `hubai-sdk` is installed in your environment." + ) + except ValueError as e: + raise ValueError(f"HubAI conversion failed: {e}") from e + + return archive_path + @property def environ(self) -> Environ: return self.cfg.ENVIRON diff --git a/luxonis_train/core/utils/export_utils.py b/luxonis_train/core/utils/export_utils.py index 4533358f..3e819108 100644 --- a/luxonis_train/core/utils/export_utils.py +++ b/luxonis_train/core/utils/export_utils.py @@ -1,5 +1,7 @@ +import os +import shutil from collections.abc import Generator -from contextlib import contextmanager +from contextlib import contextmanager, suppress from pathlib import Path from typing import Literal @@ -8,7 +10,7 @@ import luxonis_train as lxt from luxonis_train.config import ExportConfig -from luxonis_train.config.config import PreprocessingConfig +from luxonis_train.config.config import HubAIExportConfig, PreprocessingConfig @contextmanager @@ -107,7 +109,7 @@ def blobconverter_export( blob_path = blobconverter.from_onnx( model=str(onnx_path), optimizer_params=optimizer_params, - data_type=cfg.data_type.upper(), + data_type=cfg.target_precision.upper(), shaves=cfg.blobconverter.shaves, version=cfg.blobconverter.version, use_cache=False, @@ -117,6 +119,151 @@ def blobconverter_export( return Path(blob_path) +def hubai_export( + cfg: HubAIExportConfig, + target_precision: str, + archive_path: PathType, + export_path: PathType, + model_name: str, + dataset_name: str | None = None, +) -> Path: + """Convert an ONNX NNArchive to a platform-specific NNArchive using + HubAI SDK. + + If a model with the given name already exists on HubAI, a new + variant will be created under that model. Otherwise, a new model + will be created. + + @type cfg: HubAIExportConfig + @param cfg: HubAI export configuration containing platform and + params. + @type target_precision: str + @param target_precision: Target precision (int8, fp16, fp32). + @type archive_path: PathType + @param archive_path: Path to the ONNX NNArchive to convert. + @type export_path: PathType + @param export_path: Directory where the converted archive will be + saved. + @type model_name: str + @param model_name: Name for the model on HubAI. + @type dataset_name: str | None + @param dataset_name: Name of the dataset the model was trained on. + @rtype: Path + @return: Path to the converted platform-specific NNArchive. + """ + from hubai_sdk import HubAIClient + + hubai_token = os.environ.get("HUBAI_API_KEY") + if not hubai_token: + raise ValueError( + "HUBAI_API_KEY environment variable is not set. " + "Please set it to use HubAI SDK for model conversion. " + ) + + precision_map = { + "int8": "INT8", + "fp16": "FP16", + "fp32": "FP32", + } + precision = precision_map.get(target_precision.lower(), "FP16") + + client = HubAIClient(api_key=hubai_token) + archive_path = Path(archive_path) + + existing_model = None + created_new_model = False + try: + models = client.models.list_models() + if models: + existing_model = next( + (m for m in models if m.name == model_name), None + ) + except Exception as e: + logger.warning(f"Failed to check for existing model: {e}") + + variant_name = ( + f"{model_name}:{dataset_name}" if dataset_name else f"{model_name}" + ) + + base_kwargs: dict = { + "path": str(archive_path), + "target_precision": precision, + "name": variant_name, + } + + if existing_model: + base_kwargs["model_id"] = str(existing_model.id) + logger.info( + f"Model '{model_name}' already exists on HubAI. " + f"Creating new variant '{variant_name}' under existing model." + ) + else: + new_model = client.models.create_model(model_name, silent=True) + base_kwargs["model_id"] = str(new_model.id) + created_new_model = True + logger.info( + f"Created new model '{model_name}' on HubAI. " + f"Creating variant '{variant_name}' under it." + ) + + if cfg.params: + base_kwargs.update(cfg.params) + + variant_id = None + + try: + if cfg.platform == "hailo": + if "quantization_data" not in base_kwargs: + base_kwargs["quantization_data"] = "RANDOM" + response = client.convert.Hailo(**base_kwargs) + elif cfg.platform == "rvc3": + response = client.convert.RVC3(**base_kwargs) + elif cfg.platform == "rvc4": + response = client.convert.RVC4(**base_kwargs) + else: + response = client.convert.RVC2(**base_kwargs) + + if hasattr(response, "instance") and hasattr( + response.instance, "model_version_id" + ): + variant_id = str(response.instance.model_version_id) + downloaded_path = Path(response.downloaded_path) + + export_path = Path(export_path) + output_path = export_path / downloaded_path.name + + downloaded_parent = downloaded_path.parent + + shutil.move(downloaded_path, output_path) + + if downloaded_parent.exists() and downloaded_parent != Path.cwd(): + with suppress(OSError): + downloaded_parent.rmdir() + + logger.info(f"HubAI converted archive saved to {output_path}") + return output_path + finally: + if cfg.delete_remote_model: + try: + if created_new_model: + client.models.delete_model(model_name) + logger.debug( + f"Cleaned up temporary HubAI model: {model_name}" + ) + elif variant_id: + client.variants.delete_variant(variant_id) + logger.debug( + f"Cleaned up temporary HubAI variant: {variant_id}" + ) + except Exception as e: + resource_type = "model" if created_new_model else "variant" + resource_id = model_name if created_new_model else variant_id + logger.warning( + f"Failed to cleanup HubAI {resource_type} " + f"'{resource_id}': {e}" + ) + + def make_initializers_unique(onnx_path: PathType) -> None: """Each initializer that is used by multiple nodes gets duplicated so each node has its own copy. diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py index c5c7fe8d..6c8abfb3 100644 --- a/luxonis_train/lightning/luxonis_lightning.py +++ b/luxonis_train/lightning/luxonis_lightning.py @@ -887,6 +887,13 @@ def get_mlflow_logging_keys(self) -> dict[str, list[str]]: artifact_keys.add( f"{self.cfg.exporter.name or self.cfg.model.name}.onnx.tar.xz" ) + elif callback.name == "ConvertOnTrainEnd": + artifact_keys.add( + f"{self.cfg.exporter.name or self.cfg.model.name}.onnx" + ) + artifact_keys.add( + f"{self.cfg.exporter.name or self.cfg.model.name}.onnx.tar.xz" + ) elif callback.name == "TrainingProgressCallback": metric_keys.update( { diff --git a/requirements.txt b/requirements.txt index c800af1f..4ad224d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ tensorboard~=2.20 termcolor~=3.2 torchmetrics~=1.8 torchvision~=0.24 +hubai-sdk diff --git a/tests/conftest.py b/tests/conftest.py index 915b3968..e0024122 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -375,6 +375,7 @@ def opts(save_dir: Path, image_size: tuple[int, int]) -> Params: {"name": "TestOnTrainEnd", "active": False}, {"name": "ExportOnTrainEnd", "active": False}, {"name": "ArchiveOnTrainEnd", "active": False}, + {"name": "ConvertOnTrainEnd", "active": False}, {"name": "UploadCheckpoint", "active": False}, ], "tracker.save_directory": str(save_dir), diff --git a/tests/integration/test_callbacks.py b/tests/integration/test_callbacks.py index 6adaff92..8f2d9600 100644 --- a/tests/integration/test_callbacks.py +++ b/tests/integration/test_callbacks.py @@ -37,6 +37,10 @@ def test_callbacks(coco_dataset: LuxonisDataset, opts: Params, save_dir: Path): "name": "ArchiveOnTrainEnd", "params": {"preferred_checkpoint": "loss"}, }, + { + "name": "ConvertOnTrainEnd", + "params": {"preferred_checkpoint": "loss"}, + }, { "name": "GradCamCallback", "params": { diff --git a/tests/integration/test_cli_commands.py b/tests/integration/test_cli_commands.py index 8c2b0d08..710c7e12 100644 --- a/tests/integration/test_cli_commands.py +++ b/tests/integration/test_cli_commands.py @@ -15,6 +15,7 @@ from luxonis_train.__main__ import ( _yield_visualizations, archive, + convert, export, inspect, train, @@ -42,6 +43,7 @@ def test_cli_command_success( (export, {"save_path": tmp_path}), (_yield_visualizations, {}), (archive, {"executable": tmp_path / "export.onnx"}), + (convert, {"save_dir": tmp_path / "convert_output"}), ]: with subtests.test(command.__name__): res = command( @@ -97,6 +99,7 @@ def test_cli_command_success( }, ), (archive, {"config": "nonexistent.yaml"}), + (convert, {"config": "nonexistent.yaml"}), ], ) def test_cli_command_failure( diff --git a/tests/integration/test_convert.py b/tests/integration/test_convert.py new file mode 100644 index 00000000..c64dbaa6 --- /dev/null +++ b/tests/integration/test_convert.py @@ -0,0 +1,103 @@ +from pathlib import Path + +import pytest +from luxonis_ml.data import LuxonisDataset +from luxonis_ml.typing import Params + +from luxonis_train.core import LuxonisModel + + +def test_convert_basic( + coco_dataset: LuxonisDataset, opts: Params, tmp_path: Path +): + """Export + archive, without blobconverter or hubai exporter + defined.""" + config_file = "configs/detection_light_model.yaml" + opts |= { + "loader.params.dataset_name": coco_dataset.identifier, + "model.name": "test_convert_basic", + "exporter.blobconverter.active": False, + "exporter.hubai.active": False, + } + model = LuxonisModel(config_file, opts) + + save_dir = tmp_path / "convert_output" + archive_path = model.convert(save_dir=save_dir) + + assert archive_path.exists(), "Archive was not created" + assert archive_path.suffix == ".xz", "Archive should be a .xz file" + + onnx_path = model._exported_models.get("onnx") + assert onnx_path is not None, "ONNX model was not exported" + assert Path(onnx_path).exists(), "ONNX file does not exist" + + +def test_convert_with_blobconverter( + coco_dataset: LuxonisDataset, opts: Params, tmp_path: Path +): + config_file = "configs/detection_light_model.yaml" + opts |= { + "loader.params.dataset_name": coco_dataset.identifier, + "model.name": "test_convert_blob", + "exporter.blobconverter.active": True, + "exporter.hubai.active": False, + "exporter.scale_values": [255.0, 255.0, 255.0], + "exporter.mean_values": [127.5, 127.5, 127.5], + } + model = LuxonisModel(config_file, opts) + + save_dir = tmp_path / "convert_blob_output" + archive_path = model.convert(save_dir=save_dir) + + assert archive_path.exists(), "Archive was not created" + + blob_path = model._exported_models.get("blob") + assert blob_path is not None, "Blob model was not created" + assert Path(blob_path).exists(), "Blob file does not exist" + assert Path(blob_path).suffix == ".blob", ( + "Blob file should have .blob extension" + ) + + +@pytest.mark.parametrize("platform", ["rvc2", "rvc3", "rvc4"]) +def test_convert_with_hubai( + coco_dataset: LuxonisDataset, opts: Params, tmp_path: Path, platform: str +): + config_file = "configs/detection_light_model.yaml" + opts |= { + "loader.params.dataset_name": coco_dataset.identifier, + "model.name": f"test_convert_hubai_{platform}", + "exporter.blobconverter.active": False, + "exporter.hubai.active": True, + "exporter.hubai.platform": platform, + } + model = LuxonisModel(config_file, opts) + + save_dir = tmp_path / f"convert_hubai_{platform}_output" + archive_path = model.convert(save_dir=save_dir) + + assert archive_path.exists(), "Archive was not created" + + platform_archives = list(save_dir.glob("*.tar.xz")) + assert len(platform_archives) > 0, ( + f"No platform-specific archive created for {platform}" + ) + + +def test_convert_saves_to_default_directory( + coco_dataset: LuxonisDataset, opts: Params +): + """Test that convert uses default save directory when not + specified.""" + config_file = "configs/detection_light_model.yaml" + opts |= { + "loader.params.dataset_name": coco_dataset.identifier, + "model.name": "test_convert_default_dir", + "exporter.blobconverter.active": False, + "exporter.hubai.active": False, + } + model = LuxonisModel(config_file, opts) + + archive_path = model.convert() + + assert archive_path.exists(), "Archive was not created" diff --git a/tests/integration/test_mlflow_logging.py b/tests/integration/test_mlflow_logging.py index 9dcd0c54..7e0ea4cf 100644 --- a/tests/integration/test_mlflow_logging.py +++ b/tests/integration/test_mlflow_logging.py @@ -232,8 +232,7 @@ def get_config() -> Params: }, "callbacks": [ {"name": "TestOnTrainEnd"}, - {"name": "ExportOnTrainEnd"}, - {"name": "ArchiveOnTrainEnd"}, + {"name": "ConvertOnTrainEnd"}, {"name": "UploadCheckpoint"}, {"name": "DeviceStatsMonitor"}, {"name": "TrainingProgressCallback"}, diff --git a/tests/unittests/test_config.py b/tests/unittests/test_config.py index c406ee77..31fa0745 100644 --- a/tests/unittests/test_config.py +++ b/tests/unittests/test_config.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import pytest from luxonis_ml.data import LuxonisDataset @@ -124,3 +125,78 @@ def test_config_invalid(): } with pytest.raises(ValueError, match="Only one main metric"): Config.get_config(cfg) + + +@pytest.mark.parametrize( + ("callbacks_input", "expected_active"), + [ + ( + [{"name": "ConvertOnTrainEnd"}], + {"ConvertOnTrainEnd": True}, + ), + ( + [{"name": "ExportOnTrainEnd"}], + {"ExportOnTrainEnd": True}, + ), + ( + [{"name": "ArchiveOnTrainEnd"}], + {"ArchiveOnTrainEnd": True}, + ), + ( + [{"name": "ExportOnTrainEnd"}, {"name": "ArchiveOnTrainEnd"}], + {"ExportOnTrainEnd": True, "ArchiveOnTrainEnd": True}, + ), + ( + [{"name": "ConvertOnTrainEnd"}, {"name": "ExportOnTrainEnd"}], + {"ConvertOnTrainEnd": True, "ExportOnTrainEnd": False}, + ), + ( + [{"name": "ConvertOnTrainEnd"}, {"name": "ArchiveOnTrainEnd"}], + {"ConvertOnTrainEnd": True, "ArchiveOnTrainEnd": False}, + ), + ( + [ + {"name": "ConvertOnTrainEnd"}, + {"name": "ExportOnTrainEnd"}, + {"name": "ArchiveOnTrainEnd"}, + ], + { + "ConvertOnTrainEnd": True, + "ExportOnTrainEnd": False, + "ArchiveOnTrainEnd": False, + }, + ), + ( + [ + {"name": "ConvertOnTrainEnd"}, + {"name": "ExportOnTrainEnd", "active": False}, + ], + {"ConvertOnTrainEnd": True, "ExportOnTrainEnd": False}, + ), + ], +) +def test_convert_callback_deactivates_export_and_archive( + callbacks_input: list[Params], expected_active: dict[str, bool] +): + """Test that ConvertOnTrainEnd deactivates ExportOnTrainEnd and + ArchiveOnTrainEnd.""" + cfg = Config.get_config( + cast( + Params, + { + "model": {"nodes": [{"name": "ResNet"}]}, + "trainer": {"callbacks": callbacks_input}, + }, + ) + ) + + callbacks_by_name = {cb.name: cb for cb in cfg.trainer.callbacks} + + for callback_name, should_be_active in expected_active.items(): + assert callback_name in callbacks_by_name, ( + f"Callback '{callback_name}' not found in config" + ) + assert callbacks_by_name[callback_name].active == should_be_active, ( + f"Callback '{callback_name}' expected active={should_be_active}, " + f"got active={callbacks_by_name[callback_name].active}" + )