SHAP Interpretability method implementation #611
base: master
Conversation
examples/shap_stagenet_mimic4.py
Outdated
```python
# Initialize SHAP explainer with custom parameters
shap_explainer = ShapExplainer(
    model,
    use_embeddings=True,  # Use embeddings for discrete features
```
I think SHAP should be compatible with discrete tokens like the ICD codes here? Correct me if I'm wrong. I will look deeper into the full SHAP implementation later when I'm more cognitively sound.
I updated the ShapExplainer instantiation to pass just the model and default all other values, including use_embeddings, inside the __init__ method. Yes, SHAP works for ICD codes, but it will use the embeddings from the input model.
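For illustration, a minimal sketch of the simplified call described above, assuming the import path pyhealth.interpret.methods and that use_embeddings now defaults to True in __init__:

```python
# Hypothetical sketch of the simplified instantiation; the import path
# is assumed from the files listed in this PR.
from pyhealth.interpret.methods import ShapExplainer

# All options, including use_embeddings, now fall back to the defaults
# defined in ShapExplainer.__init__.
shap_explainer = ShapExplainer(model)
```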
jhnwu3
left a comment
Some other nice-to-haves:
- Can we add an entry at docs/api/interpret/shap.rst, and add it to the interpretability index interpret.rst?
- Can we check that this is compatible when the device is a GPU? Maybe through a Colab notebook? (There's a way to install the branch/repo into the Colab environment; see the sketch below.)
- I might be able to share some compute resources soon once NCSA gets back to me.
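For reference, a hypothetical one-liner for installing a PR branch into a Colab cell; <fork-owner> and <branch-name> are placeholders, not values from this PR:

```python
# In a Colab cell: install PyHealth from a specific fork and branch.
# <fork-owner> and <branch-name> are placeholders.
!pip install git+https://github.com/<fork-owner>/PyHealth.git@<branch-name>
```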
pyhealth/interpret/methods/shap.py
Outdated
```python
if coalition_size == 0 or coalition_size == n_features:
    return torch.tensor(1000.0)  # Large weight for edge cases

comb_val = math.comb(n_features - 1, coalition_size - 1)
```
Wait, isn't it binom(M, |z|) here? Why do we take n_features - 1 and coalition_size - 1 instead of n_features and coalition_size?
I updated the code in the method _compute_kernel_weight to match the equation used in the original SHAP paper:

weight = (M - 1) / (binom(M, |z|) * |z| * (M - |z|))

- I also added the .rst file as requested.
- Added "examples/shap_stagenet_mimic4.ipynb", built in Colab with a GPU.
```python
coalition_vectors = []
coalition_weights = []
coalition_preds = []
```
Can you check whether we need to add the edge-case coalitions (all features and no features) explicitly to the prediction set used to train the kernel/linear model that estimates the Shapley values here?
I've linked some captum code examples here:
Handled the edge cases and updated the code accordingly in the method _compute_kernel_shap.
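A minimal sketch of how those two edge-case coalitions might be seeded before random sampling, reusing the _compute_kernel_weight sketch above; predict_fn and the function name are illustrative, not the PR's exact code:

```python
import torch

def seed_edge_coalitions(n_features, predict_fn,
                         coalition_vectors, coalition_weights, coalition_preds):
    """Add the two deterministic edge-case coalitions first: the empty
    coalition pins the linear model to the base value, and the full
    coalition pins it to the original prediction f(x)."""
    for vec in (torch.zeros(n_features), torch.ones(n_features)):
        coalition_vectors.append(vec)
        coalition_weights.append(
            _compute_kernel_weight(n_features, int(vec.sum().item()))
        )
        coalition_preds.append(predict_fn(vec))
```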
```python
)

# Sample remaining coalitions randomly (excluding edge cases already added)
n_random_coalitions = max(0, n_coalitions - 2)
```
Hey, one last request. I've been digging deeper into the official SHAP implementation of the kernel regression. They of course do it in NumPy, mostly so people can interpret random forests and XGBoost.
But I think we can adopt some of their nice tricks in coalition sampling here:
https://github.com/shap/shap/blob/master/shap/explainers/_kernel.py
Specifically, they do a form of complement sampling to maximize how much coverage of the coalition space we get per sample.
Let me know if you need any help with this. It does look a little complicated, but everything else looks good to me.
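A minimal torch sketch of that paired/complement sampling idea; this illustrates the trick under assumed names, not the official shap code:

```python
import torch

def sample_paired_coalitions(n_features: int, n_samples: int) -> torch.Tensor:
    """Sample coalitions in complementary pairs: for each random mask z,
    also include 1 - z. Because the kernel weight depends only on |z|
    (and binom(M, |z|) = binom(M, M - |z|)), each pair shares one weight,
    and pairing reduces estimator variance for a fixed sampling budget."""
    masks = []
    for _ in range(n_samples // 2):
        size = int(torch.randint(1, n_features, (1,)))  # skip edge sizes
        perm = torch.randperm(n_features)
        z = torch.zeros(n_features)
        z[perm[:size]] = 1.0
        masks.append(z)
        masks.append(1.0 - z)  # the complement coalition
    return torch.stack(masks)
```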
@jhnwu3 Based on our Discord convo, I am keeping the torch implementation as-is to support StageNet. I did update the dataset-creation portions in the unit test and the example script. Hope this helps; please advise if further changes are required.
jhnwu3
left a comment
Just one last doc change, and then LGTM! We will test more robustly and decide if we need revisions as we interpret the model and look into it deeper (i.e., we might need to compare our implementation against another, given some workarounds).
@jhnwu3 I added the comment to the interpret.rst file and pushed my changes.
Contributor: Naveen Baskaran
Contribution Type: Interpretability method, Tests, Example
Description
This PR implements the SHAP (SHapley Additive exPlanations) interpretability method for PyHealth models, enabling users to understand which features contribute most to model predictions. SHAP is based on coalitional game theory and provides theoretically grounded feature importance scores with desirable properties like local accuracy, missingness, and consistency.
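For reference, the additive feature-attribution model that KernelSHAP fits (from the original SHAP paper) is

$$g(z') = \phi_0 + \sum_{i=1}^{M} \phi_i z'_i, \qquad z' \in \{0, 1\}^M,$$

where M is the number of features, z' is a coalition vector, and phi_i is the attribution of feature i; local accuracy requires g to match f(x) when all features are present.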
Files to Review
pyhealth/interpret/methods/__init__.py
pyhealth/interpret/methods/shap.py - Core SHAP method implementation; supports embedding-based attribution and continuous features
pyhealth/processors/tensor_processor.py - Minor fix to resolve a warning message
examples/shap_stagenet_mimic4.py - Example script showing usage of the SHAP method
tests/core/test_shap.py - Comprehensive test cases covering the main class, utility methods, and attribution methods
Results on mimic4-demo dataset