diff --git a/training/standard_callbacks.py b/training/standard_callbacks.py index 3657d6b5..73767bee 100644 --- a/training/standard_callbacks.py +++ b/training/standard_callbacks.py @@ -54,7 +54,7 @@ def correct(labels, outputs): with torch.no_grad(): for examples, labels in loader: examples = examples.to(get_platform().torch_device) - labels = labels.squeeze().to(get_platform().torch_device) + labels = labels.squeeze().to(get_platform().torch_device).long() output = model(examples) labels_size = torch.tensor(len(labels), device=get_platform().torch_device) diff --git a/training/train.py b/training/train.py index 99951b46..e3055d67 100644 --- a/training/train.py +++ b/training/train.py @@ -111,7 +111,7 @@ def train( # Otherwise, train. examples = examples.to(device=get_platform().torch_device) - labels = labels.to(device=get_platform().torch_device) + labels = labels.to(device=get_platform().torch_device).long() step_optimizer.zero_grad() model.train()