feat: support for censored likelihoods #91
base: master
Conversation
@John-Curcio Thanks for opening the PR and your effort! I'll need some time to review it.
Pull Request Overview
This PR adds support for fitting censored data by introducing a CensoredMixin class that extends univariate distributions to handle interval-censored observations. The mixin overrides objective_fn and metric_fn to compute likelihood functions using cumulative distribution functions (CDFs) for censored intervals.
- Adds `CensoredMixin` class with censored likelihood computation
- Implements `CensoredLogNormal` and `CensoredWeibull` distribution classes
- Adds comprehensive test coverage for censored data functionality
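The interval-censored likelihood such a mixin computes can be sketched as follows. This is an illustrative reconstruction, not the PR's actual code: the `censored_nll` helper and the example bounds are hypothetical; the convention that `low == hi` marks an exact (uncensored) observation follows the review comments in this thread.

```python
import torch
from torch.distributions import LogNormal

def censored_nll(dist, low, hi):
    """Negative log-likelihood mixing exact and interval-censored rows.

    Exact rows (low == hi) contribute log f(y); censored rows contribute
    log of the probability mass F(hi) - F(low) of the censoring interval.
    """
    exact = low == hi
    mass = dist.cdf(hi[~exact]) - dist.cdf(low[~exact])
    log_density = dist.log_prob(low[exact])
    return -torch.log(mass).sum() - log_density.sum()

dist = LogNormal(torch.tensor(0.0), torch.tensor(1.0))
low = torch.tensor([1.0, 0.5, 2.0])   # first row exact, rest interval-censored
hi = torch.tensor([1.0, 1.5, 3.0])
loss = censored_nll(dist, low, hi)
```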
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| xgboostlss/distributions/censored_utils.py | Core CensoredMixin class implementing censored likelihood functions |
| xgboostlss/distributions/Weibull.py | Adds CensoredWeibull class inheriting from CensoredMixin and Weibull |
| xgboostlss/distributions/LogNormal.py | Adds CensoredLogNormal class inheriting from CensoredMixin and LogNormal |
| tests/utils.py | Extends test data generation to support censored data scenarios |
| tests/test_distribution_utils/test_censored_utils.py | Test suite validating censored distribution functionality |
```python
mass = cdf_hi - cdf_low
log_density = dist.log_prob(low)
censored_inds = low != hi
loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds])
```
Copilot AI commented on Aug 8, 2025:
The log density is computed using only the lower bound, but this should only be used for exact observations (non-censored data). For censored intervals where low != hi, this log_density value is incorrect and shouldn't contribute to the loss.
Suggested change:

```python
exact_inds = (low == hi)
log_density = dist.log_prob(low[exact_inds])
loss = -torch.sum(torch.log(mass[~exact_inds])) - torch.sum(log_density)
```
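For intuition, both variants produce the same loss value on valid inputs, since masked indexing already drops the censored rows from the density sum; the suggested version simply avoids evaluating `log_prob` at censored rows in the first place. A sketch with a toy `Weibull` (the tensors here are illustrative, not from the PR's tests):

```python
import torch
from torch.distributions import Weibull

dist = Weibull(torch.tensor(1.0), torch.tensor(1.5))  # scale, concentration
low = torch.tensor([0.5, 1.0, 2.0])
hi = torch.tensor([1.5, 1.0, 3.0])   # middle row is an exact observation
mass = dist.cdf(hi) - dist.cdf(low)

# original: evaluate log_prob on every row, then mask afterwards
censored = low != hi
loss_a = -torch.sum(torch.log(mass[censored])) - torch.sum(dist.log_prob(low)[~censored])

# suggested: mask first, evaluate log_prob only on exact rows
exact = low == hi
loss_b = -torch.sum(torch.log(mass[~exact])) - torch.sum(dist.log_prob(low[exact]))
```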
```python
return super().objective_fn(predt, data)
```

```python
if data.get_weight().size == 0:
    # initialize weights as ones with correct shape
    weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()
```
Copilot AI commented on Aug 8, 2025:
Creating a tensor just to get its dtype and then converting back to numpy is inefficient. Consider using weights = np.ones((lower.shape[0], 1), dtype=lower.dtype) directly.
Suggested change:

```python
weights = np.ones((lower.shape[0], 1), dtype=lower.dtype)
```
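The reviewer's point is easy to verify: the torch round-trip and the direct numpy construction produce identical arrays, so the intermediate tensor buys nothing. The `lower` array below is illustrative, not the PR's variable:

```python
import numpy as np
import torch

lower = np.asarray([0.5, 1.0, 2.0], dtype=np.float32)

# round-trip through torch just to recover the dtype (what the PR currently does)
w_roundtrip = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()

# direct numpy construction with the same dtype
w_direct = np.ones((lower.shape[0], 1), dtype=lower.dtype)
```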
```python
predt, labels, *rest = gen_test_data(model, weights=False, censored=False)
dmat = rest[-1]
name_c, loss_c = model.dist.metric_fn(predt, dmat)
underlying_cls = model.dist.__class__.__mro__[2]
```
Copilot AI commented on Aug 8, 2025:
Using hardcoded index [2] in the MRO (Method Resolution Order) is fragile and could break if the inheritance hierarchy changes. Consider using a more explicit approach like checking class names or using hasattr to find the base distribution class.
Suggested change:

```python
# Find the first base class in the MRO that is not a censored distribution and not 'object'
underlying_cls = next(
    cls for cls in model.dist.__class__.__mro__
    if cls is not model.dist.__class__ and not cls.__name__.startswith("Censored") and cls is not object
)
```
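A minimal, self-contained illustration of why the hardcoded index is fragile and what the MRO search buys. The toy classes mirror the PR's naming but are stand-ins, not the real distributions:

```python
# Toy hierarchy mirroring the pattern CensoredWeibull(CensoredMixin, Weibull)
class Weibull: ...
class CensoredMixin: ...
class CensoredWeibull(CensoredMixin, Weibull): ...

# Fragile: __mro__[2] happens to be Weibull here, but the index silently
# points elsewhere if anyone inserts a class into the hierarchy.
hardcoded = CensoredWeibull.__mro__[2]

# Robust: search the MRO for the first base that is neither the class
# itself, nor a Censored* helper, nor object.
underlying_cls = next(
    cls for cls in CensoredWeibull.__mro__
    if cls is not CensoredWeibull
    and not cls.__name__.startswith("Censored")
    and cls is not object
)
```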
```python
dist = self.distribution(**dict(zip(self.distribution_arg_names, params_transformed)))
# compute cdf bounds: convert lower & upper once to tensor with correct dtype
low = torch.as_tensor(lower, dtype=params_transformed[0].dtype).reshape(-1, 1)
hi  = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
```
Copilot AI commented on Aug 8, 2025:
Inconsistent spacing: 'hi =' has two spaces before the equals sign while 'low =' on the previous line has one space. This should be consistent.
Suggested change:

```python
hi = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
```
@claude Review this PR

Claude encountered an error. I'll analyze this and get back to you.

@John-Curcio Can you please look into these?
- `CensoredMixin` class to extend fitting univariate distributions to censored data
- `CensoredLogNormal` and `CensoredWeibull`

`xgb.DMatrix` already has `label_lower_bound`, `label_upper_bound` for the user to specify right-, left-, or interval-censored data. This PR adds a `CensoredMixin` class which simply overrides `objective_fn`, `metric_fn` to accommodate censored data. So to fit a LogNormal distribution to such a dataset, just use `CensoredLogNormal` instead of `LogNormal`. I've added `CensoredLogNormal` and `CensoredWeibull`. I'm happy to further update docs/add examples.
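As a sketch of the label encoding this relies on: exact, right-censored, and interval-censored observations can all be expressed through the two bound arrays that `xgb.DMatrix` accepts. The values below are illustrative, and the mixin's `lower != upper` censoring convention is taken from the review comments above:

```python
import numpy as np

# A common encoding for censored labels (illustrative values):
#   exact observation y       -> lower = upper = y
#   right-censored at c       -> lower = c,     upper = +inf
#   interval-censored [a, b]  -> lower = a,     upper = b
# These arrays would be passed as
#   xgb.DMatrix(X, label_lower_bound=lower, label_upper_bound=upper)
lower = np.array([2.0, 3.0, 1.0])
upper = np.array([2.0, np.inf, 4.0])

# rows where lower != upper are treated as censored
is_censored = lower != upper
```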