From 9f39d10c1666de893db1bd8ac41fbbb7687d3842 Mon Sep 17 00:00:00 2001
From: Nora Belrose
Date: Thu, 6 Jul 2023 03:00:12 +0000
Subject: [PATCH 1/3] LDA MVP

---
 elk/training/lda_reporter.py | 121 +++++++++++++++++++++++++++++++++++
 elk/training/train.py        |  25 ++++----
 2 files changed, 135 insertions(+), 11 deletions(-)
 create mode 100644 elk/training/lda_reporter.py

diff --git a/elk/training/lda_reporter.py b/elk/training/lda_reporter.py
new file mode 100644
index 00000000..8bed3021
--- /dev/null
+++ b/elk/training/lda_reporter.py
@@ -0,0 +1,121 @@
+"""An ELK reporter network."""
+
+from dataclasses import dataclass
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import Tensor, nn
+
+from .reporter import Reporter, ReporterConfig
+
+
+@dataclass
+class LdaReporterConfig(ReporterConfig):
+    """Configuration for an LdaReporter."""
+
+    @classmethod
+    def reporter_class(cls) -> type[Reporter]:
+        return LdaReporter
+
+
+class LdaReporter(Reporter):
+    """Linear Discriminant Analysis (LDA) reporter.
+
+    Args:
+        cfg: The reporter configuration.
+        in_features: The number of input features.
+        num_classes: The number of classes for tracking the running means.
+
+    Attributes:
+        config: The reporter configuration.
+        sample_sizes: Running count of the training samples seen so far.
+        weight: The reporter weight vector of shape [1, in_features], set by
+            `fit()` to the LDA direction `pinv(sigma) @ (mu_pos - mu_neg)`.
+    """
+
+    config: LdaReporterConfig
+
+    sample_sizes: Tensor
+    class_means: Tensor
+    weight: Tensor
+
+    def __init__(
+        self,
+        cfg: LdaReporterConfig,
+        in_features: int,
+        num_classes: int = 2,
+        *,
+        device: str | torch.device | None = None,
+        dtype: torch.dtype | None = None,
+        num_variants: int = 1,
+    ):
+        super().__init__()
+        self.config = cfg
+        self.in_features = in_features
+        self.num_classes = num_classes
+        self.num_variants = num_variants
+
+        # Learnable Platt scaling parameters
+        self.bias = nn.Parameter(torch.zeros(1, device=device, dtype=dtype))
+        self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype))
+
+        # Running statistics
+        self.register_buffer(
+            "class_means",
+            torch.zeros(num_classes, in_features, device=device, dtype=dtype),
+        )
+        self.register_buffer(
+            "sample_sizes",
+            torch.zeros((), device=device, dtype=torch.long),
+        )
+        # Reporter weights
+        self.register_buffer(
+            "weight",
+            torch.zeros(1, in_features, device=device, dtype=dtype),
+        )
+
+    def forward(self, hiddens: Tensor) -> Tensor:
+        """Return the predicted log odds on input `x`."""
+        raw_scores = hiddens @ self.weight.mT
+        return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
+
+    def fit(self, hiddens: Tensor, labels: Tensor):
+        """Fit the probe to the contrast set `hiddens`.
+
+        Args:
+            hiddens: The contrast set of shape [batch, variants, choices, dim].
+            labels: Integer labels of shape [batch], giving the index of the
+                correct choice in each contrast set. The fitted direction
+                is stored in `self.weight`.
+ """ + (n, v, k, d) = hiddens.shape + + # Sanity checks + assert k > 1, "Must provide at least two hidden states" + assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]" + + # Create a true-false label for each element of the contrast set + mask = F.one_hot(labels.long(), num_classes=k).bool() # [n, k] + mask = mask[:, None, :].expand_as(hiddens[..., 0]) # [n, 1, k] -> [n, v, k] + + mu_pos = hiddens[mask].mean(0) + mu_neg = hiddens[~mask].mean(0) + sigma = rearrange(hiddens, "n v k d -> d (n v k)").cov() + + w = torch.linalg.pinv(sigma) @ (mu_pos - mu_neg) + self.weight.data = w[None] + + def save(self, path: Path | str) -> None: + """Save the reporter to a file.""" + # We basically never want to instantiate the reporter on the same device + # it happened to be trained on, so we save the state dict as CPU tensors. + # Bizarrely, this also seems to save a LOT of disk space in some cases. + state = {k: v.cpu() for k, v in self.state_dict().items()} + state.update( + in_features=self.in_features, + num_classes=self.num_classes, + num_variants=self.num_variants, + ) + torch.save(state, path) diff --git a/elk/training/train.py b/elk/training/train.py index e0aee29b..323a4a87 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -17,6 +17,7 @@ from ..utils.typing import assert_type from .ccs_reporter import CcsReporter, CcsReporterConfig from .eigen_reporter import EigenReporter, EigenReporterConfig +from .lda_reporter import LdaReporter, LdaReporterConfig from .reporter import ReporterConfig @@ -25,7 +26,12 @@ class Elicit(Run): """Full specification of a reporter training run.""" net: ReporterConfig = subgroups( - {"ccs": CcsReporterConfig, "eigen": EigenReporterConfig}, default="eigen" + { + "ccs": CcsReporterConfig, + "eigen": EigenReporterConfig, + "lda": LdaReporterConfig, + }, + default="eigen", ) """Config for building the reporter network.""" @@ -74,21 +80,15 @@ def apply_to_layer( if not all(other_h.shape[-2] == k for other_h, _, _ in rest): raise ValueError("All datasets must have the same number of classes") + train_loss = None reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) + if isinstance(self.net, CcsReporterConfig): assert len(train_dict) == 1, "CCS only supports single-task training" reporter = CcsReporter(self.net, d, device=device, num_variants=v) train_loss = reporter.fit(first_train_h) - (val_h, val_gt, _) = next(iter(val_dict.values())) - x0, x1 = first_train_h.unbind(2) - val_x0, val_x1 = val_h.unbind(2) - pseudo_auroc = reporter.check_separability( - train_pair=(x0, x1), - val_pair=(val_x0, val_x1), - ) - (_, v, k, _) = first_train_h.shape reporter.platt_scale( to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(), @@ -112,12 +112,16 @@ def apply_to_layer( ) reporter.update(train_h) - pseudo_auroc = None train_loss = reporter.fit_streaming() reporter.platt_scale( torch.cat(label_list), torch.cat(hidden_list), ) + elif isinstance(self.net, LdaReporterConfig): + reporter = LdaReporter( + self.net, d, num_classes=k, num_variants=v, device=device + ) + reporter.fit(first_train_h, train_gt) else: raise ValueError(f"Unknown reporter config type: {type(self.net)}") @@ -150,7 +154,6 @@ def apply_to_layer( **meta, "ensembling": mode, **evaluate_preds(val_gt, val_credences, mode).to_dict(), - "pseudo_auroc": pseudo_auroc, "train_loss": train_loss, } ) From b924457f6bf38bb52afe3238ca2f819152a7913a Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Mon, 10 Jul 2023 08:02:23 +0000 Subject: [PATCH 2/3] LdaFitter 
MVP --- elk/training/common.py | 7 +- elk/training/lda.py | 54 ++++++++++++++++ elk/training/lda_reporter.py | 121 ----------------------------------- elk/training/train.py | 11 ++-- 4 files changed, 63 insertions(+), 130 deletions(-) create mode 100644 elk/training/lda.py delete mode 100644 elk/training/lda_reporter.py diff --git a/elk/training/common.py b/elk/training/common.py index d93ff006..62d9a6b1 100644 --- a/elk/training/common.py +++ b/elk/training/common.py @@ -18,7 +18,7 @@ class FitterConfig(Serializable, decode_into_subclasses=True): @dataclass class Reporter(PlattMixin): weight: Tensor - eraser: LeaceEraser + eraser: LeaceEraser | None = None def __post_init__(self): # Platt scaling parameters @@ -27,5 +27,8 @@ def __post_init__(self): def __call__(self, hiddens: Tensor) -> Tensor: """Return the predicted log odds on input `x`.""" - raw_scores = self.eraser(hiddens) @ self.weight.mT + if self.eraser is not None: + hiddens = self.eraser(hiddens) + + raw_scores = hiddens @ self.weight.mT return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) diff --git a/elk/training/lda.py b/elk/training/lda.py new file mode 100644 index 00000000..fdf72b70 --- /dev/null +++ b/elk/training/lda.py @@ -0,0 +1,54 @@ +"""An ELK reporter network.""" + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from concept_erasure import optimal_linear_shrinkage +from einops import rearrange +from torch import Tensor + +from .common import FitterConfig, Reporter + + +@dataclass +class LdaConfig(FitterConfig): + """Configuration for an LdaFitter.""" + + l2_penalty: float = 0.0 + + +class LdaFitter: + """Linear Discriminant Analysis (LDA)""" + + config: LdaConfig + + def __init__(self, cfg: LdaConfig): + super().__init__() + self.config = cfg + + def fit(self, hiddens: Tensor, labels: Tensor) -> Reporter: + """Fit the probe to the contrast set `hiddens`. + + Args: + hiddens: The contrast set of shape [batch, variants, choices, dim]. + """ + (n, _, k, _) = hiddens.shape + + # Sanity checks + assert k > 1, "Must provide at least two hidden states" + assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]" + + # Create a true-false label for each element of the contrast set + mask = F.one_hot(labels.long(), num_classes=k).bool() # [n, k] + mask = mask[:, None, :].expand_as(hiddens[..., 0]) # [n, 1, k] -> [n, v, k] + + mu_pos = hiddens[mask].mean(0) + mu_neg = hiddens[~mask].mean(0) + + sigma = rearrange(hiddens, "n v k d -> d (n v k)").cov() + sigma = optimal_linear_shrinkage(sigma, n) + torch.linalg.diagonal(sigma).add_(self.config.l2_penalty) + + w = torch.linalg.solve(sigma, mu_pos - mu_neg) + return Reporter(w[None]) diff --git a/elk/training/lda_reporter.py b/elk/training/lda_reporter.py deleted file mode 100644 index 8bed3021..00000000 --- a/elk/training/lda_reporter.py +++ /dev/null @@ -1,121 +0,0 @@ -"""An ELK reporter network.""" - -from dataclasses import dataclass -from pathlib import Path - -import torch -import torch.nn.functional as F -from einops import rearrange -from torch import Tensor, nn - -from .reporter import Reporter, ReporterConfig - - -@dataclass -class LdaReporterConfig(ReporterConfig): - """Configuration for an LdaReporter.""" - - @classmethod - def reporter_class(cls) -> type[Reporter]: - return LdaReporter - - -class LdaReporter(Reporter): - """Linear Discriminant Analysis (LDA) reporter. - - Args: - cfg: The reporter configuration. - in_features: The number of input features. 
-        num_classes: The number of classes for tracking the running means.
-
-    Attributes:
-        config: The reporter configuration.
-        sample_sizes: Running count of the training samples seen so far.
-        weight: The reporter weight vector of shape [1, in_features], set by
-            `fit()` to the LDA direction `pinv(sigma) @ (mu_pos - mu_neg)`.
-    """
-
-    config: LdaReporterConfig
-
-    sample_sizes: Tensor
-    class_means: Tensor
-    weight: Tensor
-
-    def __init__(
-        self,
-        cfg: LdaReporterConfig,
-        in_features: int,
-        num_classes: int = 2,
-        *,
-        device: str | torch.device | None = None,
-        dtype: torch.dtype | None = None,
-        num_variants: int = 1,
-    ):
-        super().__init__()
-        self.config = cfg
-        self.in_features = in_features
-        self.num_classes = num_classes
-        self.num_variants = num_variants
-
-        # Learnable Platt scaling parameters
-        self.bias = nn.Parameter(torch.zeros(1, device=device, dtype=dtype))
-        self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype))
-
-        # Running statistics
-        self.register_buffer(
-            "class_means",
-            torch.zeros(num_classes, in_features, device=device, dtype=dtype),
-        )
-        self.register_buffer(
-            "sample_sizes",
-            torch.zeros((), device=device, dtype=torch.long),
-        )
-        # Reporter weights
-        self.register_buffer(
-            "weight",
-            torch.zeros(1, in_features, device=device, dtype=dtype),
-        )
-
-    def forward(self, hiddens: Tensor) -> Tensor:
-        """Return the predicted log odds on input `x`."""
-        raw_scores = hiddens @ self.weight.mT
-        return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
-
-    def fit(self, hiddens: Tensor, labels: Tensor):
-        """Fit the probe to the contrast set `hiddens`.
-
-        Args:
-            hiddens: The contrast set of shape [batch, variants, choices, dim].
-            labels: Integer labels of shape [batch], giving the index of the
-                correct choice in each contrast set. The fitted direction
-                is stored in `self.weight`.
-        """
-        (n, v, k, d) = hiddens.shape
-
-        # Sanity checks
-        assert k > 1, "Must provide at least two hidden states"
-        assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]"
-
-        # Create a true-false label for each element of the contrast set
-        mask = F.one_hot(labels.long(), num_classes=k).bool()  # [n, k]
-        mask = mask[:, None, :].expand_as(hiddens[..., 0])  # [n, 1, k] -> [n, v, k]
-
-        mu_pos = hiddens[mask].mean(0)
-        mu_neg = hiddens[~mask].mean(0)
-        sigma = rearrange(hiddens, "n v k d -> d (n v k)").cov()
-
-        w = torch.linalg.pinv(sigma) @ (mu_pos - mu_neg)
-        self.weight.data = w[None]
-
-    def save(self, path: Path | str) -> None:
-        """Save the reporter to a file."""
-        # We basically never want to instantiate the reporter on the same device
-        # it happened to be trained on, so we save the state dict as CPU tensors.
-        # Bizarrely, this also seems to save a LOT of disk space in some cases.
- state = {k: v.cpu() for k, v in self.state_dict().items()} - state.update( - in_features=self.in_features, - num_classes=self.num_classes, - num_variants=self.num_variants, - ) - torch.save(state, path) diff --git a/elk/training/train.py b/elk/training/train.py index 77d689c5..12710352 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -18,7 +18,7 @@ from .ccs_reporter import CcsConfig, CcsReporter from .common import FitterConfig from .eigen_reporter import EigenFitter, EigenFitterConfig -from .lda_reporter import LdaReporter, LdaReporterConfig +from .lda import LdaConfig, LdaFitter @dataclass @@ -29,7 +29,7 @@ class Elicit(Run): { "ccs": CcsConfig, "eigen": EigenFitterConfig, - "lda": LdaReporterConfig, + "lda": LdaConfig, }, default="eigen", ) @@ -118,11 +118,8 @@ def apply_to_layer( torch.cat(label_list), torch.cat(hidden_list), ) - elif isinstance(self.net, LdaReporterConfig): - reporter = LdaReporter( - self.net, d, num_classes=k, num_variants=v, device=device - ) - reporter.fit(first_train_h, train_gt) + elif isinstance(self.net, LdaConfig): + reporter = LdaFitter(self.net).fit(first_train_h, train_gt) else: raise ValueError(f"Unknown reporter config type: {type(self.net)}") From 384b97d708113c51125d90eaba84da2d111b731b Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 11 Jul 2023 08:28:38 +0000 Subject: [PATCH 3/3] Anchor regression & counterfactual ridge --- elk/training/lda.py | 80 ++++++++++++++++++++++++++++-------- elk/utils/__init__.py | 4 +- elk/utils/math_util.py | 28 +++++++++---- tests/test_eigen_reporter.py | 4 +- tests/test_math.py | 4 +- 5 files changed, 87 insertions(+), 33 deletions(-) diff --git a/elk/training/lda.py b/elk/training/lda.py index fdf72b70..66a1a74c 100644 --- a/elk/training/lda.py +++ b/elk/training/lda.py @@ -8,6 +8,7 @@ from einops import rearrange from torch import Tensor +from ..utils.math_util import cov, cov_mean_fused from .common import FitterConfig, Reporter @@ -15,8 +16,19 @@ class LdaConfig(FitterConfig): """Configuration for an LdaFitter.""" + anchor_gamma: float = 1.0 + """Gamma parameter for anchor regression.""" + + invariance_weight: float = 0.5 + """Weight of the prompt invariance term in the loss.""" + l2_penalty: float = 0.0 + def __post_init__(self): + assert self.anchor_gamma >= 0, "anchor_gamma must be non-negative" + assert 0 <= self.invariance_weight <= 1, "invariance_weight must be in [0, 1]" + assert self.l2_penalty >= 0, "l2_penalty must be non-negative" + class LdaFitter: """Linear Discriminant Analysis (LDA)""" @@ -31,24 +43,56 @@ def fit(self, hiddens: Tensor, labels: Tensor) -> Reporter: """Fit the probe to the contrast set `hiddens`. Args: - hiddens: The contrast set of shape [batch, variants, choices, dim]. + hiddens: The contrast set of shape [batch, variants, (choices,) dim]. + labels: Integer labels of shape [batch]. 
""" - (n, _, k, _) = hiddens.shape - - # Sanity checks - assert k > 1, "Must provide at least two hidden states" - assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]" - - # Create a true-false label for each element of the contrast set - mask = F.one_hot(labels.long(), num_classes=k).bool() # [n, k] - mask = mask[:, None, :].expand_as(hiddens[..., 0]) # [n, 1, k] -> [n, v, k] - - mu_pos = hiddens[mask].mean(0) - mu_neg = hiddens[~mask].mean(0) - - sigma = rearrange(hiddens, "n v k d -> d (n v k)").cov() - sigma = optimal_linear_shrinkage(sigma, n) - torch.linalg.diagonal(sigma).add_(self.config.l2_penalty) + n, v, *_ = hiddens.shape + assert n == labels.shape[0], "hiddens and labels must have the same batch size" + + # This is a contrast set; create a true-false label for each element + if len(hiddens.shape) == 4: + hiddens = rearrange(hiddens, "n v k d -> (n k) v d") + labels = F.one_hot(labels.long()).flatten() + + n = len(labels) + counts = (labels.sum(), n - labels.sum()) + else: + counts = torch.bincount(labels) + assert len(counts) == 2, "Only binary classification is supported for now" + + # Construct targets for the least-squares dual problem + z = torch.where(labels.bool(), n / counts[0], -n / counts[1]).unsqueeze(1) + + # Adjust X and Z for anchor regression + gamma = self.config.anchor_gamma + if gamma != 1.0: + # Implicitly compute n x n orthogonal projection onto the column space of + # the anchor variables without materializing the whole matrix. Since the + # anchors are one-hot, it turns out this is equivalent to adding a multiple + # of the anchor-conditional means. + # In general you're supposed to adjust the labels too, but we don't need + # to do that because by construction the anchor-conditional means of the + # labels are already all zero. + hiddens = hiddens + (gamma**0.5 - 1) * hiddens.mean(0) + + # We can decompose the covariance matrix into the sum of the within-cluster + # covariance and the between-cluster covariance. This allows us to put extra + # weight on the within-cluster variance to encourage invariance to the prompt. + # NOTE: We're not applying shrinkage to each cluster covariance matrix because + # we're averaging over them, which should reduce the variance of the estimate + # a lot. Shrinkage could make MSE worse in this case. 
+ S_between = optimal_linear_shrinkage(cov(hiddens.mean(1)), n) + S_within = cov_mean_fused(hiddens) + + # Convex combination but multiply by 2 to keep the same scale + alpha = 2 * self.config.invariance_weight + S = alpha * S_within + (2 - alpha) * S_between + + # Add ridge penalty + torch.linalg.diagonal(S).add_(self.config.l2_penalty) + + # Broadcast the labels across variants + sigma_xz = cov(hiddens, z.expand_as(hiddens[..., 0]).unsqueeze(-1)) + w = torch.linalg.solve(S, sigma_xz.squeeze(-1)) - w = torch.linalg.solve(sigma, mu_pos - mu_neg) return Reporter(w[None]) diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py index 22b92b75..316f7eb6 100644 --- a/elk/utils/__init__.py +++ b/elk/utils/__init__.py @@ -10,14 +10,14 @@ ) from .gpu_utils import select_usable_devices from .hf_utils import instantiate_model, instantiate_tokenizer, is_autoregressive -from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained +from .math_util import cov, cov_mean_fused, stochastic_round_constrained from .pretty import Color, colorize from .tree_utils import pytree_map from .typing import assert_type, float_to_int16, int16_to_float32 __all__ = [ "assert_type", - "batch_cov", + "cov", "Color", "colorize", "cov_mean_fused", diff --git a/elk/utils/math_util.py b/elk/utils/math_util.py index 4ae9daee..d901518c 100644 --- a/elk/utils/math_util.py +++ b/elk/utils/math_util.py @@ -5,18 +5,28 @@ from torch import Tensor -@torch.jit.script -def batch_cov(x: Tensor) -> Tensor: - """Compute a batch of covariance matrices. +def cov( + x: Tensor, y: Tensor | None = None, dim: int | None = None, unbiased: bool = False +) -> Tensor: + """Compute the (cross-)covariance matrix for `x` (and `y`). Args: - x: A tensor of shape [..., n, d]. - - Returns: - A tensor of shape [..., d, d]. + x: A tensor of shape [*, d]. + y: An optional tensor of shape [*, k]. If not provided, defaults to `x`. + dim: The dimension to reduce over. If not provided, defaults to all but the + last dimension. + unbiased: Whether to use Bessel's correction. 
""" - x_ = x - x.mean(dim=-2, keepdim=True) - return x_.mT @ x_ / x_.shape[-2] + if y is None: + y = x + if dim is None: + dim = 0 + x = x.flatten(0, -2) + y = y.flatten(0, -2) + + x = x - x.mean(dim) + y = y - y.mean(dim) + return x.T @ y / (x.shape[dim] - unbiased) @torch.jit.script diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py index 6303cb03..815790f3 100644 --- a/tests/test_eigen_reporter.py +++ b/tests/test_eigen_reporter.py @@ -1,7 +1,7 @@ import torch from elk.training import EigenFitter, EigenFitterConfig -from elk.utils import batch_cov, cov_mean_fused +from elk.utils import cov, cov_mean_fused def test_eigen_reporter(): @@ -31,7 +31,7 @@ def test_eigen_reporter(): # Check that the streaming covariance is correct neg_centroids, pos_centroids = x_neg.mean(dim=1), x_pos.mean(dim=1) - true_cov = 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids)) + true_cov = 0.5 * (cov(neg_centroids) + cov(pos_centroids)) torch.testing.assert_close(reporter.intercluster_cov, true_cov) # Check that the streaming negative covariance is correct diff --git a/tests/test_math.py b/tests/test_math.py index 34984d8f..07b677aa 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -6,12 +6,12 @@ from hypothesis import given from hypothesis import strategies as st -from elk.utils import batch_cov, cov_mean_fused, stochastic_round_constrained +from elk.utils import cov, cov_mean_fused, stochastic_round_constrained def test_cov_mean_fused(): X = torch.randn(10, 500, 100, dtype=torch.float64) - cov_gt = batch_cov(X).mean(dim=0) + cov_gt = cov(X).mean(dim=0) cov_fused = cov_mean_fused(X) assert torch.allclose(cov_gt, cov_fused)