Skip to content

Commit c1b3780

Browse files
committed
Fixes missing DataParallel in rs train
1 parent 344c0bb commit c1b3780

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

robosat/tools/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def main(args):
6363
os.makedirs(model["common"]["checkpoint"], exist_ok=True)
6464

6565
num_classes = len(dataset["common"]["classes"])
66-
net = FPNSegmentation(num_classes).to(device)
66+
net = FPNSegmentation(num_classes)
67+
net = DataParallel(net)
68+
net = net.to(device)
6769

6870
if model["common"]["cuda"]:
6971
torch.backends.cudnn.benchmark = True

0 commit comments

Comments
 (0)