diff --git a/KernelBench/level1/94_MSELoss.py b/KernelBench/level1/94_MSELoss.py index 2dc77eed..087b700c 100644 --- a/KernelBench/level1/94_MSELoss.py +++ b/KernelBench/level1/94_MSELoss.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from torch.distributions import Pareto class Model(nn.Module): """ @@ -19,8 +20,9 @@ def forward(self, predictions, targets): dim = 1 def get_inputs(): - scale = torch.rand(()) - return [torch.rand(batch_size, *input_shape)*scale, torch.rand(batch_size, *input_shape)] + predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + targets = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + return [predictions, targets] def get_init_inputs(): return []