diff --git a/collections/nemo_cv/nemo_cv/modules/README_RetinaNet.md b/collections/nemo_cv/nemo_cv/modules/README_RetinaNet.md new file mode 100644 index 000000000000..1bf3af553029 --- /dev/null +++ b/collections/nemo_cv/nemo_cv/modules/README_RetinaNet.md @@ -0,0 +1,20 @@ + +# The docker image that is tested working with RetinaNet is nvcr.io/nvidia/pytorch:19.09-py3 + +# To run RetinaNet Nemo Wrapper follow the below steps: + +docker run -it --rm --ipc=host --gpus all -v {path to the project}:/workspace nvcr.io/nvidia/pytorch:19.09-py3 + +pip install --no-cache-dir git+https://github.com/nvidia/retinanet-examples + +git clone https://github.com/NVIDIA/NeMo.git + +cd NeMo/nemo +# Change the version of Pytorch in NeMo/nemo/setup.py to 'torch==1.2.0' + +python setup.py install + +cd .. +cd collections/nemo_cv/nemo_cv/modules/ + +python pascal_retinanet.py train model_mydataset.pth --backbone ResNet18FPN --classes 20 --iters 10000 --val-iters 1000 --lr 0.0005 --resize 512 --jitter 480 640 --images /workspace/PASCAL_VOC/JPEGImages/ --annotations /workspace/PASCAL_VOC/pascal_train2012.json --val-annotations /workspace/PASCAL_VOC/pascal_val2012.json diff --git a/collections/nemo_cv/nemo_cv/modules/pascal_retinanet.py b/collections/nemo_cv/nemo_cv/modules/pascal_retinanet.py new file mode 100644 index 000000000000..0011bbe233d8 --- /dev/null +++ b/collections/nemo_cv/nemo_cv/modules/pascal_retinanet.py @@ -0,0 +1,31 @@ +# Copyright (C) , NVIDIA INC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +__author__ = "Hussein Al-barazanchi" + + +import sys +from retinanet_module import * + + +def main(args=None): + 'Entry point for the retinanet command' + + args = parse(args or sys.argv[1:]) + + detector = RetinaNet(args) + detector.execute(args) + +if __name__ == '__main__': + main() diff --git a/collections/nemo_cv/nemo_cv/modules/retinanet_module.py b/collections/nemo_cv/nemo_cv/modules/retinanet_module.py new file mode 100644 index 000000000000..408e0e597752 --- /dev/null +++ b/collections/nemo_cv/nemo_cv/modules/retinanet_module.py @@ -0,0 +1,61 @@ +# Copyright (C) , NVIDIA INC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__author__ = "Hussein Al-barazanchi" + + +import torch + +from retinanet.model import Model +from retinanet.main import parse, load_model, worker + +from nemo.backends.pytorch.nm import TrainableNM + + +class RetinaNet(TrainableNM): + """ + Wrapper class around the RetinaNet model. + """ + + @staticmethod
 def create_ports(): + + return None, None + + def __init__(self, args): + """ + Creates the RetinaNet model. + + Args: + args: parsed command-line arguments (the namespace returned by parse); + forwarded to load_model to construct or restore the detector. 
+ """ + + super().__init__() + + # Create + self.model, self.state = load_model(args, verbose=True) + if self.model: + self.model.share_memory() + + def forward(self, images, bounding_boxes, targets): + pass + + def execute(self, args): + + world = torch.cuda.device_count() + if args.command == 'export' or world <= 1: + worker(0, args, 1, self.model, self.state) + else: + torch.multiprocessing.spawn(worker, args=(args, world, self.model, self.state), nprocs=world)