diff --git a/elk/training/common.py b/elk/training/common.py
index d93ff006..62d9a6b1 100644
--- a/elk/training/common.py
+++ b/elk/training/common.py
@@ -18,7 +18,7 @@ class FitterConfig(Serializable, decode_into_subclasses=True):
 @dataclass
 class Reporter(PlattMixin):
     weight: Tensor
-    eraser: LeaceEraser
+    eraser: LeaceEraser | None = None
 
     def __post_init__(self):
         # Platt scaling parameters
@@ -27,5 +27,8 @@ def __post_init__(self):
 
     def __call__(self, hiddens: Tensor) -> Tensor:
         """Return the predicted log odds on input `x`."""
-        raw_scores = self.eraser(hiddens) @ self.weight.mT
+        if self.eraser is not None:
+            hiddens = self.eraser(hiddens)
+
+        raw_scores = hiddens @ self.weight.mT
         return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
diff --git a/elk/training/lda.py b/elk/training/lda.py
new file mode 100644
index 00000000..66a1a74c
--- /dev/null
+++ b/elk/training/lda.py
@@ -0,0 +1,98 @@
+"""An ELK reporter network."""
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+from concept_erasure import optimal_linear_shrinkage
+from einops import rearrange
+from torch import Tensor
+
+from ..utils.math_util import cov, cov_mean_fused
+from .common import FitterConfig, Reporter
+
+
+@dataclass
+class LdaConfig(FitterConfig):
+    """Configuration for an LdaFitter."""
+
+    anchor_gamma: float = 1.0
+    """Gamma parameter for anchor regression."""
+
+    invariance_weight: float = 0.5
+    """Weight of the prompt invariance term in the loss."""
+
+    l2_penalty: float = 0.0
+
+    def __post_init__(self):
+        assert self.anchor_gamma >= 0, "anchor_gamma must be non-negative"
+        assert 0 <= self.invariance_weight <= 1, "invariance_weight must be in [0, 1]"
+        assert self.l2_penalty >= 0, "l2_penalty must be non-negative"
+
+
+class LdaFitter:
+    """Linear Discriminant Analysis (LDA)"""
+
+    config: LdaConfig
+
+    def __init__(self, cfg: LdaConfig):
+        super().__init__()
+        self.config = cfg
+
+    def fit(self, hiddens: Tensor, labels: Tensor) -> Reporter:
+        """Fit the probe to the contrast set `hiddens`.
+
+        Args:
+            hiddens: The contrast set of shape [batch, variants, (choices,) dim].
+            labels: Integer labels of shape [batch].
+        """
+        n, v, *_ = hiddens.shape
+        assert n == labels.shape[0], "hiddens and labels must have the same batch size"
+
+        # This is a contrast set; create a true-false label for each element
+        if len(hiddens.shape) == 4:
+            hiddens = rearrange(hiddens, "n v k d -> (n k) v d")
+            labels = F.one_hot(labels.long()).flatten()
+
+            n = len(labels)
+            counts = (labels.sum(), n - labels.sum())
+        else:
+            counts = torch.bincount(labels)
+            assert len(counts) == 2, "Only binary classification is supported for now"
+
+        # Construct targets for the least-squares dual problem
+        z = torch.where(labels.bool(), n / counts[0], -n / counts[1]).unsqueeze(1)
+
+        # Adjust X and Z for anchor regression
+        gamma = self.config.anchor_gamma
+        if gamma != 1.0:
+            # Implicitly compute n x n orthogonal projection onto the column space of
+            # the anchor variables without materializing the whole matrix. Since the
+            # anchors are one-hot, it turns out this is equivalent to adding a multiple
+            # of the anchor-conditional means.
+            # In general you're supposed to adjust the labels too, but we don't need
+            # to do that because by construction the anchor-conditional means of the
+            # labels are already all zero.
+            hiddens = hiddens + (gamma**0.5 - 1) * hiddens.mean(0)
+
+        # We can decompose the covariance matrix into the sum of the within-cluster
+        # covariance and the between-cluster covariance. This allows us to put extra
+        # weight on the within-cluster variance to encourage invariance to the prompt.
+        # NOTE: We're not applying shrinkage to each cluster covariance matrix because
+        # we're averaging over them, which should reduce the variance of the estimate
+        # a lot. Shrinkage could make MSE worse in this case.
+        S_between = optimal_linear_shrinkage(cov(hiddens.mean(1)), n)
+        S_within = cov_mean_fused(hiddens)
+
+        # Convex combination but multiply by 2 to keep the same scale
+        alpha = 2 * self.config.invariance_weight
+        S = alpha * S_within + (2 - alpha) * S_between
+
+        # Add ridge penalty
+        torch.linalg.diagonal(S).add_(self.config.l2_penalty)
+
+        # Broadcast the labels across variants
+        sigma_xz = cov(hiddens, z.expand_as(hiddens[..., 0]).unsqueeze(-1))
+        w = torch.linalg.solve(S, sigma_xz.squeeze(-1))
+
+        return Reporter(w[None])
diff --git a/elk/training/train.py b/elk/training/train.py
index 8392f2d9..12710352 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -18,6 +18,7 @@
 from .ccs_reporter import CcsConfig, CcsReporter
 from .common import FitterConfig
 from .eigen_reporter import EigenFitter, EigenFitterConfig
+from .lda import LdaConfig, LdaFitter
 
 
 @dataclass
@@ -25,7 +26,12 @@ class Elicit(Run):
     """Full specification of a reporter training run."""
 
     net: FitterConfig = subgroups(
-        {"ccs": CcsConfig, "eigen": EigenFitterConfig}, default="eigen"
+        {
+            "ccs": CcsConfig,
+            "eigen": EigenFitterConfig,
+            "lda": LdaConfig,
+        },
+        default="eigen",
     )
     """Config for building the reporter network."""
 
@@ -74,6 +80,7 @@ def apply_to_layer(
         if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same number of classes")
 
+        train_loss = None
         reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
         train_loss = None
 
@@ -111,6 +118,8 @@ def apply_to_layer(
                 torch.cat(label_list),
                 torch.cat(hidden_list),
             )
+        elif isinstance(self.net, LdaConfig):
+            reporter = LdaFitter(self.net).fit(first_train_h, train_gt)
         else:
             raise ValueError(f"Unknown reporter config type: {type(self.net)}")
 
diff --git a/elk/utils/__init__.py b/elk/utils/__init__.py
index 22b92b75..316f7eb6 100644
--- a/elk/utils/__init__.py
+++ b/elk/utils/__init__.py
@@ -10,14 +10,14 @@
 )
 from .gpu_utils import select_usable_devices
 from .hf_utils import instantiate_model, instantiate_tokenizer, is_autoregressive
-from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained
+from .math_util import cov, cov_mean_fused, stochastic_round_constrained
 from .pretty import Color, colorize
 from .tree_utils import pytree_map
 from .typing import assert_type, float_to_int16, int16_to_float32
 
 __all__ = [
     "assert_type",
-    "batch_cov",
+    "cov",
     "Color",
     "colorize",
     "cov_mean_fused",
diff --git a/elk/utils/math_util.py b/elk/utils/math_util.py
index 4ae9daee..d901518c 100644
--- a/elk/utils/math_util.py
+++ b/elk/utils/math_util.py
@@ -5,18 +5,28 @@
 from torch import Tensor
 
 
-@torch.jit.script
-def batch_cov(x: Tensor) -> Tensor:
-    """Compute a batch of covariance matrices.
+def cov(
+    x: Tensor, y: Tensor | None = None, dim: int | None = None, unbiased: bool = False
+) -> Tensor:
+    """Compute the (cross-)covariance matrix for `x` (and `y`).
 
     Args:
-        x: A tensor of shape [..., n, d].
-
-    Returns:
-        A tensor of shape [..., d, d].
+        x: A tensor of shape [*, d].
+        y: An optional tensor of shape [*, k]. If not provided, defaults to `x`.
+        dim: The dimension to reduce over. If not provided, defaults to all but the
+            last dimension.
+        unbiased: Whether to use Bessel's correction.
     """
-    x_ = x - x.mean(dim=-2, keepdim=True)
-    return x_.mT @ x_ / x_.shape[-2]
+    if y is None:
+        y = x
+    if dim is None:
+        dim = 0
+        x = x.flatten(0, -2)
+        y = y.flatten(0, -2)
+
+    x = x - x.mean(dim)
+    y = y - y.mean(dim)
+    return x.T @ y / (x.shape[dim] - unbiased)
 
 
 @torch.jit.script
diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py
index 6303cb03..815790f3 100644
--- a/tests/test_eigen_reporter.py
+++ b/tests/test_eigen_reporter.py
@@ -1,7 +1,7 @@
 import torch
 
 from elk.training import EigenFitter, EigenFitterConfig
-from elk.utils import batch_cov, cov_mean_fused
+from elk.utils import cov, cov_mean_fused
 
 
 def test_eigen_reporter():
@@ -31,7 +31,7 @@ def test_eigen_reporter():
 
     # Check that the streaming covariance is correct
     neg_centroids, pos_centroids = x_neg.mean(dim=1), x_pos.mean(dim=1)
-    true_cov = 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids))
+    true_cov = 0.5 * (cov(neg_centroids) + cov(pos_centroids))
     torch.testing.assert_close(reporter.intercluster_cov, true_cov)
 
     # Check that the streaming negative covariance is correct
diff --git a/tests/test_math.py b/tests/test_math.py
index 34984d8f..07b677aa 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -6,12 +6,12 @@
 from hypothesis import given
 from hypothesis import strategies as st
 
-from elk.utils import batch_cov, cov_mean_fused, stochastic_round_constrained
+from elk.utils import cov, cov_mean_fused, stochastic_round_constrained
 
 
 def test_cov_mean_fused():
     X = torch.randn(10, 500, 100, dtype=torch.float64)
-    cov_gt = batch_cov(X).mean(dim=0)
+    cov_gt = torch.stack([cov(x) for x in X]).mean(dim=0)
     cov_fused = cov_mean_fused(X)
     assert torch.allclose(cov_gt, cov_fused)
 
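For reviewers, here is a minimal, self-contained sketch (not part of the patch) of the target coding that `LdaFitter.fit` relies on. It assumes a toy two-class dataset with no variant axis, and it simplifies `S` to the plain pooled covariance, skipping the within/between split, shrinkage, anchor regression, and the ridge term; the local `cov` helper mirrors the patched one for 2-D inputs.

```python
import torch

torch.manual_seed(0)

# Toy binary-classification data: n points in d dimensions.
n, d = 1000, 8
x = torch.randn(n, d, dtype=torch.float64)
labels = (torch.rand(n) < 0.3).long()
x[labels == 1] += 2.0  # shift the positive class


def cov(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Biased cross-covariance of 2-D inputs, matching the patched `cov`."""
    a = a - a.mean(0)
    b = b - b.mean(0)
    return a.T @ b / a.shape[0]


# Least-squares dual targets: +n/n_1 for the positive class, -n/n_0 for the negative.
counts = torch.bincount(labels)
z = torch.where(labels.bool(), n / counts[1], -n / counts[0]).to(x.dtype).unsqueeze(1)

# With this coding, the cross-covariance of x and z is exactly the class-mean gap.
mean_gap = x[labels == 1].mean(0) - x[labels == 0].mean(0)
torch.testing.assert_close(cov(x, z).squeeze(-1), mean_gap)

# So the fitted weight is a Fisher-style discriminant direction S^{-1} (mu_1 - mu_0).
S = cov(x, x)
w = torch.linalg.solve(S, cov(x, z).squeeze(-1))
```

In the patch itself `S` is instead a convex combination of the within-prompt and between-prompt covariance estimates (with optional shrinkage, anchor regression, and an `l2_penalty` ridge term), but the target construction and the final `torch.linalg.solve` have the same shape as above.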