
Conversation

@naveenkcb
Contributor

Contributor: Naveen Baskaran

Contribution Type: Interpretability method, Tests, Example

Description

This PR implements the SHAP (SHapley Additive exPlanations) interpretability method for PyHealth models, enabling users to understand which features contribute most to model predictions. SHAP is based on coalitional game theory and provides theoretically grounded feature importance scores with desirable properties like local accuracy, missingness, and consistency.
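To make the properties above concrete, here is a toy illustration (not the PR's code) of exact Shapley values computed by enumerating every coalition, with absent features replaced by a baseline. The function name and the linear toy model are purely for demonstration; Kernel SHAP approximates this computation when enumeration is infeasible.

```python
import math
from itertools import combinations

def exact_shapley_values(predict, features, baseline):
    """Exact Shapley values for a small feature set by enumerating
    all coalitions. predict() maps a feature dict to a scalar."""
    n = len(features)
    names = list(features)
    phi = {}
    for i in names:
        others = [f for f in names if f != i]
        total = 0.0
        for k in range(n):
            for coalition in combinations(others, k):
                # Marginal contribution of feature i: model output with
                # i present minus output with i replaced by its baseline.
                with_i = {f: (features[f] if f in coalition or f == i
                              else baseline[f]) for f in names}
                without_i = {f: (features[f] if f in coalition
                                 else baseline[f]) for f in names}
                weight = (math.factorial(k) * math.factorial(n - k - 1)
                          / math.factorial(n))
                total += weight * (predict(with_i) - predict(without_i))
        phi[i] = total
    return phi

# Toy linear model: attributions recover coefficient * (x - baseline),
# and they sum to predict(x) - predict(baseline) (local accuracy).
predict = lambda x: 2.0 * x["age"] + 3.0 * x["bmi"]
phi = exact_shapley_values(predict,
                           {"age": 1.0, "bmi": 2.0},
                           {"age": 0.0, "bmi": 0.0})
print(phi)  # {'age': 2.0, 'bmi': 6.0}
```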

Files to Review

pyhealth/interpret/methods/__init__.py
pyhealth/interpret/methods/shap.py - Core SHAP method implementation. Supports embedding-based attribution and continuous features.
pyhealth/processors/tensor_processor.py - Minor fix to resolve a warning message.
examples/shap_stagenet_mimic4.py - Example script showing usage of the SHAP method.
tests/core/test_shap.py - Comprehensive test cases covering the main class, utility methods, and attribution methods.

Results on mimic4-demo dataset

[screenshot: SHAP results on the mimic4-demo dataset]

# Initialize SHAP explainer with custom parameters
shap_explainer = ShapExplainer(
    model,
    use_embeddings=True,  # Use embeddings for discrete features
)
Collaborator

I think SHAP should be compatible with discrete tokens like ICD codes here? Correct me if I'm wrong. Will look deeper into understanding the full implementation of SHAP here later when I'm more cognitively sound.

Contributor Author
@naveenkcb Nov 17, 2025

I updated the ShapExplainer class instance creation to pass just the model and default all other values, including "use_embeddings", inside the init method. Yes, SHAP works for ICD codes, but it will use the embeddings from the input model.

Collaborator
@jhnwu3 left a comment

Some other nice-to-haves:

  1. Can we add an entry for docs/api/interpret/shap.rst, and add it to the interpretability index interpret.rst?
  2. Can we check that this is compatible when the device is on GPU? Maybe through a Colab notebook? (There's a way to install the branch/repo into the Colab environment.)
  3. I might be able to share some compute resources soon once NCSA gets back to me.

if coalition_size == 0 or coalition_size == n_features:
    return torch.tensor(1000.0)  # Large weight for edge cases

comb_val = math.comb(n_features - 1, coalition_size - 1)
Collaborator

Wait, isn't it binom(M, |z|) here? Why do we take n_features - 1 and coalition_size - 1 instead of n_features and coalition_size?

Contributor Author

I updated the code to match the equation used in the original SHAP paper in the method _compute_kernel_weight:

weight = (M - 1) / (binom(M, |z|) * |z| * (M - |z|))

  1. I also added the .rst file as requested.
  2. Added "examples/shap_stagenet_mimic4.ipynb" using Colab with GPU.
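The paper's kernel weight can be sketched as follows. This is a minimal illustration of the formula quoted above; the function name kernel_shap_weight and the large-constant stand-in for the infinite edge-case weights are illustrative, not necessarily the PR's exact code.

```python
import math

def kernel_shap_weight(n_features: int, coalition_size: int) -> float:
    """Kernel SHAP weight from Lundberg & Lee (2017):
    pi(z) = (M - 1) / (C(M, |z|) * |z| * (M - |z|)).
    The empty and full coalitions have infinite weight in theory;
    a large constant is a common practical stand-in."""
    M, s = n_features, coalition_size
    if s == 0 or s == M:
        return 1e6  # enforce the exact-match constraints at the extremes
    return (M - 1) / (math.comb(M, s) * s * (M - s))

print(kernel_shap_weight(5, 2))  # 4 / (10 * 2 * 3) ≈ 0.0667
```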

coalition_vectors = []
coalition_weights = []
coalition_preds = []

Collaborator

Can you check whether we need to add the edge-case coalitions (all features and no features) explicitly to the prediction set used to train the kernel/linear model that estimates the Shapley values here?

I've linked some captum code examples here:

https://github.com/meta-pytorch/captum/blob/master/captum/attr/_core/kernel_shap.py
https://github.com/meta-pytorch/captum/blob/master/captum/attr/_core/lime.py

Contributor Author

Handled the edge cases and updated the code accordingly in the method _compute_kernel_shap.

)

# Sample remaining coalitions randomly (excluding edge cases already added)
n_random_coalitions = max(0, n_coalitions - 2)
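The deterministic-edge-cases-plus-random-sampling scheme shown above can be sketched as follows. The function name and the uniform size-sampling scheme are illustrative, not the PR's exact code.

```python
import torch

def sample_coalitions(n_features: int, n_coalitions: int) -> torch.Tensor:
    """Binary coalition matrix with the two edge cases (no features,
    all features) always included first, plus random coalitions."""
    rows = [torch.zeros(n_features), torch.ones(n_features)]
    n_random = max(0, n_coalitions - 2)  # edge cases already added
    for _ in range(n_random):
        # Pick a nonempty, non-full coalition size, then choose
        # which features are present.
        size = int(torch.randint(1, n_features, (1,)))
        idx = torch.randperm(n_features)[:size]
        z = torch.zeros(n_features)
        z[idx] = 1.0
        rows.append(z)
    return torch.stack(rows)

Z = sample_coalitions(6, 10)
print(Z.shape)  # torch.Size([10, 6])
```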
Collaborator

Hey, one last request: I've been digging deeper into the official shap implementation of kernel regression. They do it in NumPy, of course, mostly so people can interpret random forests and XGBoost.

But I think we can adopt some of their nice coalition-sampling tricks here:

https://github.com/shap/shap/blob/master/shap/explainers/_kernel.py
Specifically, it seems they do some type of complement sampling to maximize how much coverage we get from the samples we draw at a time.

https://github.com/shap/shap/blob/ace49bf463a802f18725a869a969c060a192e3f8/shap/explainers/_kernel.py#L480

Let me know if you need any help with this. It does look a little complicated, but everything else looks good to me.
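The complement-sampling trick referenced above can be sketched roughly as follows: every sampled coalition z is paired with its complement 1 - z, as shap's KernelExplainer does, which balances coverage of coalition sizes per independent draw. This is a hypothetical sketch, not the linked implementation.

```python
import torch

def sample_paired_coalitions(n_features: int, n_pairs: int) -> torch.Tensor:
    """Antithetic (complement) coalition sampling: each random
    coalition z contributes two rows, z and 1 - z."""
    rows = []
    for _ in range(n_pairs):
        size = int(torch.randint(1, n_features, (1,)))
        z = torch.zeros(n_features)
        z[torch.randperm(n_features)[:size]] = 1.0
        rows.append(z)
        rows.append(1.0 - z)  # the complementary coalition
    return torch.stack(rows)

Z = sample_paired_coalitions(5, 4)
print(Z.shape)  # torch.Size([8, 5])
```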

Contributor Author

@jhnwu3 based on our Discord convo, I am keeping the torch implementation as is to support StageNet. I did update the dataset-creation portions in the unit test and example script. Hope this helps; please advise if further changes are required.

Collaborator
@jhnwu3 left a comment

Just one last doc change, and then LGTM! Will test more robustly and decide if we need revisions as we interpret the model and look into it deeper (i.e., we might need to compare our implementation against another, given some workarounds).

@naveenkcb
Contributor Author

@jhnwu3 I added the comment to interpret.rst file and pushed my changes.
