Ensembling over layers #259
base: main
@@ -1,15 +1,22 @@
 from dataclasses import asdict, dataclass
-from typing import Literal

 import torch
 from einops import repeat
 from torch import Tensor

+from ..utils.types import PromptEnsembling
 from .accuracy import AccuracyResult, accuracy_ci
 from .calibration import CalibrationError, CalibrationEstimate
 from .roc_auc import RocAucResult, roc_auc_ci


+@dataclass
+class LayerOutput:
+    val_gt: Tensor
+    val_credences: Tensor
+    meta: dict
+
+
 @dataclass(frozen=True)
 class EvalResult:
     """The result of evaluating a classifier."""
@@ -26,7 +33,7 @@ class EvalResult:
     cal_thresh: float | None
     """The threshold used to compute the calibrated accuracy."""

-    def to_dict(self, prefix: str = "") -> dict[str, float]:
+    def to_dict(self, prefix: str = "") -> dict[str, float | None]:
         """Convert the result to a dictionary."""
         acc_dict = {f"{prefix}acc_{k}": v for k, v in asdict(self.accuracy).items()}
         cal_acc_dict = (
@@ -49,67 +56,164 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
         }


+def calc_auroc(
+    y_logits: Tensor,
+    y_true: Tensor,
+    ensembling: PromptEnsembling,
+    num_classes: int,
+) -> RocAucResult:
+    """
+    Calculate the AUROC
+
+    Args:
+        y_true: Ground truth tensor of shape (n,).
+        y_logits: Predicted class tensor of shape (n, num_variants, num_classes).
+        prompt_ensembling: The prompt_ensembling mode.
+        num_classes: The number of classes.
+
+    Returns:
+        RocAucResult: A dictionary containing the AUROC and confidence interval.
+    """
+    if ensembling == PromptEnsembling.NONE:
+        auroc = roc_auc_ci(
+            to_one_hot(y_true, num_classes).long().flatten(1), y_logits.flatten(1)
+        )
+    elif ensembling in (PromptEnsembling.PARTIAL, PromptEnsembling.FULL):
+        # Pool together the negative and positive class logits
+        if num_classes == 2:
+            auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0])
+        else:
+            auroc = roc_auc_ci(to_one_hot(y_true, num_classes).long(), y_logits)
+    else:
+        raise ValueError(f"Unknown mode: {ensembling}")
+
+    return auroc
+
+
+def calc_calibrated_accuracies(y_true, pos_probs) -> AccuracyResult:
+    """
+    Calculate the calibrated accuracies
+
+    Args:
+        y_true: Ground truth tensor of shape (n,).
+        pos_probs: Predicted class tensor of shape (n, num_variants, num_classes).
+
+    Returns:
+        AccuracyResult: A dictionary containing the accuracy and confidence interval.
+    """
+    cal_thresh = pos_probs.float().quantile(y_true.float().mean()).item()
+    cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
+    cal_acc = accuracy_ci(y_true, cal_preds, cal_thresh)
+    return cal_acc
+
+
+def calc_calibrated_errors(y_true, pos_probs) -> CalibrationEstimate:
+    """
+    Calculate the expected calibration error.
+
+    Args:
+        y_true: Ground truth tensor of shape (n,).
+        y_logits: Predicted class tensor of shape (n, num_variants, num_classes).
+
+    Returns:
+        CalibrationEstimate:
+    """
+    cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten())
+    cal_err = cal.compute()
+    return cal_err
+
+
+def calc_accuracies(y_logits, y_true) -> AccuracyResult:
+    """
+    Calculate the accuracy
+
+    Args:
+        y_true: Ground truth tensor of shape (n,).
+        y_logits: Predicted class tensor of shape (n, num_variants, num_classes).
+
+    Returns:
+        AccuracyResult: A dictionary containing the accuracy and confidence interval.
+    """
+    y_pred = y_logits.argmax(dim=-1)
+    return accuracy_ci(y_true, y_pred)
+
+
 def evaluate_preds(
     y_true: Tensor,
     y_logits: Tensor,
-    ensembling: Literal["none", "partial", "full"] = "none",
+    prompt_ensembling: PromptEnsembling = PromptEnsembling.NONE,
 ) -> EvalResult:
     """
     Evaluate the performance of a classification model.

     Args:
-        y_true: Ground truth tensor of shape (N,).
-        y_logits: Predicted class tensor of shape (N, variants, n_classes).
+        y_true: Ground truth tensor of shape (n,).
+        y_logits: Predicted class tensor of shape (n, num_variants, num_classes).
+        prompt_ensembling: The prompt_ensembling mode.

     Returns:
         dict: A dictionary containing the accuracy, AUROC, and ECE.
     """
-    (n, v, c) = y_logits.shape
-    assert y_true.shape == (n,)
+    y_logits, y_true, num_classes = prepare(y_logits, y_true, prompt_ensembling)
+    return calc_eval_results(y_true, y_logits, prompt_ensembling, num_classes)
+
+
+def prepare(y_logits: Tensor, y_true: Tensor, prompt_ensembling: PromptEnsembling):
+    """
+    Prepare the logits and ground truth for evaluation
+    """
+    (n, num_variants, num_classes) = y_logits.shape
+    assert y_true.shape == (n,), f"y_true.shape: {y_true.shape} is not equal to n: {n}"

-    if ensembling == "full":
+    if prompt_ensembling == PromptEnsembling.FULL:
         y_logits = y_logits.mean(dim=1)
     else:
-        y_true = repeat(y_true, "n -> n v", v=v)
+        y_true = repeat(y_true, "n -> n v", v=num_variants)

-    THRESHOLD = 0.5
-    if ensembling == "none":
-        y_pred = y_logits[..., 1].gt(THRESHOLD).to(torch.int)
-    else:
-        y_pred = y_logits.argmax(dim=-1)
-
-    acc = accuracy_ci(y_true, y_pred)
-
-    if ensembling == "none":
-        auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
-    elif ensembling in ("partial", "full"):
-        # Pool together the negative and positive class logits
-        if c == 2:
-            auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0])
-        else:
-            auroc = roc_auc_ci(to_one_hot(y_true, c).long(), y_logits)
-    else:
-        raise ValueError(f"Unknown mode: {ensembling}")
-
-    cal_acc = None
-    cal_err = None
-    cal_thresh = None
-
-    if c == 2:
-        pooled_logits = (
-            y_logits[..., 1]
-            if ensembling == "none"
-            else y_logits[..., 1] - y_logits[..., 0]
-        )
-        pos_probs = torch.sigmoid(pooled_logits)
-
-        # Calibrated accuracy
-        cal_thresh = pos_probs.float().quantile(y_true.float().mean()).item()
-        cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
-        cal_acc = accuracy_ci(y_true, cal_preds)
-
-        cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten())
-        cal_err = cal.compute()
+    return y_logits, y_true, num_classes
+
+
+def calc_eval_results(
+    y_true: Tensor,
+    y_logits: Tensor,
+    prompt_ensembling: PromptEnsembling,
+    num_classes: int,
+) -> EvalResult:
+    """
+    Calculate the evaluation results
+
+    Args:
+        y_true: Ground truth tensor of shape (n,).
+        y_logits: Predicted class tensor of shape (n, num_variants, num_classes).
+        prompt_ensembling: The prompt_ensembling mode.
+
+    Returns:
+        EvalResult: The result of evaluating a classifier containing the accuracy,
+        calibrated accuracies, calibrated errors, and AUROC.
+    """
+    acc = calc_accuracies(y_logits=y_logits, y_true=y_true)
+
+    pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0])
+    cal_acc, cal_thresh = (
+        calc_calibrated_accuracies(y_true=y_true, pos_probs=pos_probs)
+        if num_classes == 2
+        else None,
+        None,
+    )
+    cal_err = (
+        calc_calibrated_errors(y_true=y_true, pos_probs=pos_probs)
+        if num_classes == 2
+        else None
+    )
+
+    auroc = calc_auroc(
+        y_logits=y_logits,
+        y_true=y_true,
+        ensembling=prompt_ensembling,
+        num_classes=num_classes,
+    )

     return EvalResult(acc, cal_acc, cal_err, auroc, cal_thresh)
@@ -127,3 +231,49 @@ def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
     """
     one_hot_labels = labels.new_zeros(*labels.shape, n_classes)
     return one_hot_labels.scatter_(-1, labels.unsqueeze(-1).long(), 1)
+
+
+def layer_ensembling(
+    layer_outputs: list[LayerOutput], prompt_ensembling: PromptEnsembling
+) -> EvalResult:
+    """
+    Return EvalResult after prompt_ensembling
+    the probe output of the middle to last layers
+
+    Args:
+        layer_outputs: A list of LayerOutput containing the ground truth and
+        predicted class tensor of shape (n, num_variants, num_classes).
+        prompt_ensembling: The prompt_ensembling mode.
+
+    Returns:
+        EvalResult: The result of evaluating a classifier containing the accuracy,
+        calibrated accuracies, calibrated errors, and AUROC.
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    y_logits_collection = []
+
+    num_classes = 2
+    y_true = layer_outputs[0].val_gt.to(device)
+
+    for layer_output in layer_outputs:
+        # all y_trues are identical, so just get the first
+        y_logits = layer_output.val_credences.to(device)
+        y_logits, y_true, num_classes = prepare(
+            y_logits=y_logits,
+            y_true=layer_outputs[0].val_gt.to(device),
+            prompt_ensembling=prompt_ensembling,
+        )
+        y_logits_collection.append(y_logits)
+
+    # get logits and ground_truth from middle to last layer
+    middle_index = len(layer_outputs) // 2
Collaborator: In some ways I think we should allow the layers over which we ensemble to be configurable. E.g. sometimes the last layers perform worse.

Author: Yeah, it makes sense to make it configurable. However, I'm curious: how would you decide which layers to pick?
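A minimal sketch of what a configurable layer range could look like (hypothetical; the `select_layer_logits` helper and its `layer_indices` parameter are not part of this PR):

```python
from typing import Sequence

import torch


def select_layer_logits(
    y_logits_collection: list[torch.Tensor],
    layer_indices: Sequence[int] | None = None,
) -> torch.Tensor:
    """Stack the logits of the chosen layers; default to middle-to-last."""
    if layer_indices is None:
        # current PR behaviour: ensemble the middle to last layers
        layer_indices = range(len(y_logits_collection) // 2, len(y_logits_collection))
    return torch.stack([y_logits_collection[i] for i in layer_indices])
```

Callers could then pass something like `layer_indices=range(10, 20)` from a config option, while the default keeps the current middle-to-last behaviour.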
+    y_logits_stacked = torch.stack(y_logits_collection[middle_index:])
+    # layer prompt_ensembling of the stacked logits
+    y_logits_stacked_mean = torch.mean(y_logits_stacked, dim=0)
Collaborator: It seems like the ensembling is done by taking the mean over layers, rather than concatenating. This isn't super clear from comments/docstrings, and hard to tell from reading the code because the shapes aren't commented.
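For reference, the reduction here is a mean over the layer axis. A small sketch with the shapes written out (the concrete sizes are made up; with `PromptEnsembling.FULL`, `prepare` would already have averaged away the variant axis, so each per-layer tensor would be `(n, num_classes)` instead):

```python
import torch

num_layers_used, n, num_variants, num_classes = 12, 100, 5, 2

# one logits tensor per ensembled layer, each of shape (n, num_variants, num_classes)
y_logits_collection = [
    torch.randn(n, num_variants, num_classes) for _ in range(num_layers_used)
]

# stack over layers -> (num_layers_used, n, num_variants, num_classes)
y_logits_stacked = torch.stack(y_logits_collection)

# mean over dim=0, the layer axis -> (n, num_variants, num_classes);
# the layers are averaged, not concatenated
y_logits_stacked_mean = torch.mean(y_logits_stacked, dim=0)
assert y_logits_stacked_mean.shape == (n, num_variants, num_classes)
```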
+
+    return calc_eval_results(
+        y_true=y_true,
+        y_logits=y_logits_stacked_mean,
+        prompt_ensembling=prompt_ensembling,
+        num_classes=num_classes,
+    )
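A rough usage sketch for the new entry point (the import paths, sizes, and random credences are assumptions for illustration, not taken from this PR):

```python
import torch

# import paths assumed from the relative imports in the diff
from elk.metrics.eval import LayerOutput, layer_ensembling
from elk.utils.types import PromptEnsembling

num_layers, n, num_variants, num_classes = 24, 100, 5, 2

# fake per-layer probe outputs; the ground truth is shared across layers
val_gt = torch.randint(0, num_classes, (n,))
layer_outputs = [
    LayerOutput(
        val_gt=val_gt,
        val_credences=torch.randn(n, num_variants, num_classes),
        meta={"layer": i},
    )
    for i in range(num_layers)
]

result = layer_ensembling(layer_outputs, PromptEnsembling.PARTIAL)
print(result.to_dict())
```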