run_prediction.py fails with MPS (with diff solution) #744

@oytuntez

Description

When running on MPS, FastSurfer fails because of a dtype compatibility issue. It allocates the prediction tensor as torch.float16, but MPS (Metal Performance Shaders) on Apple Silicon has limited support for float16 operations, particularly the add_ operation with alpha scaling. This fix was already applied to HypVINN, but FastSurfer still needs it: on MPS devices, the tensor should use torch.float32 instead.

Below is a diff in case it helps. I would like to fork the repo and submit this as a PR, but I haven't been able to do so yet.

diff --git forkSrcPrefix/FastSurferCNN/run_prediction.py forkDstPrefix/FastSurferCNN/run_prediction.py
index 277c3233fe2da41c5a91aac96653ad4c339bed97..d4c063e75a135e43a05b790a39814f056d440933 100644
--- forkSrcPrefix/FastSurferCNN/run_prediction.py
+++ forkDstPrefix/FastSurferCNN/run_prediction.py
@@ -387,9 +387,11 @@ class RunModelOnData:
             Predicted classes.
         """
         shape = orig_data.shape + (self.get_num_classes(),)
+        # Use float32 for MPS devices due to limited float16 support
+        dtype = torch.float32 if self.viewagg_device.type == "mps" else torch.float16
         kwargs = {
             "device": self.viewagg_device,
-            "dtype": torch.float16,
+            "dtype": dtype,
             "requires_grad": False,
         }
 
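For anyone who wants to check the behavior outside FastSurfer, here is a minimal sketch of the same device-dependent dtype choice. It assumes PyTorch 1.12 or later, where the "mps" device type exists. The helper names `select_aggregation_dtype` and `aggregate_views` are hypothetical and not part of FastSurfer. The loop only illustrates the in-place `add_(..., alpha=...)` accumulation pattern that this report identifies as problematic in float16 on MPS.

```python
import torch


def select_aggregation_dtype(device: torch.device) -> torch.dtype:
    """Hypothetical helper mirroring the diff: float32 on MPS, float16 elsewhere."""
    # float16 halves memory for the aggregation tensor, but MPS has limited
    # float16 support, so fall back to float32 there.
    return torch.float32 if device.type == "mps" else torch.float16


def aggregate_views(view_probs, weights, device: torch.device) -> torch.Tensor:
    """Hypothetical sketch: weighted in-place sum of per-view class probabilities."""
    dtype = select_aggregation_dtype(device)
    out = torch.zeros(view_probs[0].shape, device=device, dtype=dtype, requires_grad=False)
    for probs, weight in zip(view_probs, weights):
        # add_(other, alpha=w) computes out += w * other in place; this is the
        # alpha-scaled add_ the issue reports as unreliable in float16 on MPS.
        out.add_(probs.to(device=device, dtype=dtype), alpha=weight)
    return out
```

On CUDA or CPU the aggregation tensor stays in float16 as before. Only the MPS path pays the extra memory cost of float32.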
