When running with MPS, there is a compatibility issue: FastSurfer uses torch.float16 for the prediction tensor, but MPS (Metal Performance Shaders) on Apple Silicon has limited support for float16 operations, particularly for the in-place add_ operation with alpha scaling. This fix was already applied to HypVINN, but it is still needed in FastSurfer as well. On MPS, we need to use torch.float32 instead.
Below is a diff in case it helps. I would like to fork the repo and open this as a PR, but I haven't been able to do that yet.
```diff
diff --git forkSrcPrefix/FastSurferCNN/run_prediction.py forkDstPrefix/FastSurferCNN/run_prediction.py
index 277c3233fe2da41c5a91aac96653ad4c339bed97..d4c063e75a135e43a05b790a39814f056d440933 100644
--- forkSrcPrefix/FastSurferCNN/run_prediction.py
+++ forkDstPrefix/FastSurferCNN/run_prediction.py
@@ -387,9 +387,11 @@ class RunModelOnData:
         Predicted classes.
         """
         shape = orig_data.shape + (self.get_num_classes(),)
+        # Use float32 for MPS devices due to limited float16 support
+        dtype = torch.float32 if self.viewagg_device.type == "mps" else torch.float16
         kwargs = {
             "device": self.viewagg_device,
-            "dtype": torch.float16,
+            "dtype": dtype,
             "requires_grad": False,
         }
```