diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 07bedf0..431c2eb 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -187,7 +187,7 @@ def main(kwargs_dict: dict = {}): shard_dim=shard_dim, device_type=device.type, ) - model = model.to(device).to(memory_format=torch.contiguous_format) + model = model.to(device, memory_format=torch.channels_last_3d) # Wrap with DistConvDDP that corrects gradient scaling for dc submesh model = DistConvDDP( model,