From 011efd8438800fe50ac85f2f1ad0c37e0354d57a Mon Sep 17 00:00:00 2001 From: Artur Mostowski Date: Sat, 24 Jul 2021 17:06:23 +0200 Subject: [PATCH] [FIX] change labels type from int32 to long --- training/standard_callbacks.py | 2 +- training/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/training/standard_callbacks.py b/training/standard_callbacks.py index 3657d6b5..73767bee 100644 --- a/training/standard_callbacks.py +++ b/training/standard_callbacks.py @@ -54,7 +54,7 @@ def correct(labels, outputs): with torch.no_grad(): for examples, labels in loader: examples = examples.to(get_platform().torch_device) - labels = labels.squeeze().to(get_platform().torch_device) + labels = labels.squeeze().to(get_platform().torch_device).long() output = model(examples) labels_size = torch.tensor(len(labels), device=get_platform().torch_device) diff --git a/training/train.py b/training/train.py index 99951b46..e3055d67 100644 --- a/training/train.py +++ b/training/train.py @@ -111,7 +111,7 @@ def train( # Otherwise, train. examples = examples.to(device=get_platform().torch_device) - labels = labels.to(device=get_platform().torch_device) + labels = labels.to(device=get_platform().torch_device).long() step_optimizer.zero_grad() model.train()