diff --git a/utils.py b/utils.py index a83b0fd6..3abd5b2f 100644 --- a/utils.py +++ b/utils.py @@ -74,7 +74,7 @@ def __init__(self, batches, batch_size, device): self.batches = batches self.n_batches = len(batches) // batch_size self.residue = False # 记录batch数量是否为整数 - if len(batches) % self.n_batches != 0: + if len(batches) % batch_size != 0: self.residue = True self.index = 0 self.device = device