diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index d6054e33..44aa0411 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -1,6 +1,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
 from pathlib import Path
+from typing import Literal
 
 import pandas as pd
 import torch
@@ -9,6 +10,7 @@
 from ..files import elk_reporter_dir
 from ..metrics import evaluate_preds
 from ..run import Run
+from ..training.multi_reporter import MultiReporter, SingleReporter
 from ..utils import Color
 
 
@@ -30,7 +32,7 @@ def execute(self, highlight_color: Color = "cyan"):
 
     @torch.inference_mode()
     def apply_to_layer(
-        self, layer: int, devices: list[str], world_size: int
+        self, layer: int, devices: list[str], world_size: int, probe_per_prompt: bool
     ) -> dict[str, pd.DataFrame]:
         """Evaluate a single reporter on a single layer."""
         device = self.get_device(devices, world_size)
@@ -38,39 +40,70 @@ def apply_to_layer(
 
         experiment_dir = elk_reporter_dir() / self.source
 
-        reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
-        reporter = torch.load(reporter_path, map_location=device)
+        def load_reporter() -> SingleReporter | MultiReporter:
+            # check if experiment_dir / "reporters" has .pt files
+            first = next((experiment_dir / "reporters").iterdir())
+            if not first.suffix == ".pt":
+                return MultiReporter.load(
+                    experiment_dir / "reporters", layer, device=device
+                )
+            else:
+                path = experiment_dir / "reporters" / f"layer_{layer}.pt"
+                return torch.load(path, map_location=device)
+
+        reporter = load_reporter()
 
         row_bufs = defaultdict(list)
 
-        for ds_name, (val_h, val_gt, _) in val_output.items():
-            meta = {"dataset": ds_name, "layer": layer}
-
-            val_credences = reporter(val_h)
-            for mode in ("none", "partial", "full"):
-                row_bufs["eval"].append(
-                    {
-                        **meta,
-                        "ensembling": mode,
-                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
-                    }
-                )
-            lr_dir = experiment_dir / "lr_models"
-            if not self.skip_supervised and lr_dir.exists():
-                with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
-                    lr_models = torch.load(f, map_location=device)
-                if not isinstance(lr_models, list):  # backward compatibility
-                    lr_models = [lr_models]
-
-                for i, model in enumerate(lr_models):
-                    model.eval()
-                    row_bufs["lr_eval"].append(
-                        {
-                            "ensembling": mode,
-                            "inlp_iter": i,
-                            **meta,
-                            **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
-                        }
-                    )
+        def eval_all(
+            reporter: SingleReporter | MultiReporter,
+            prompt_index: int | Literal["multi"] | None = None,
+            i: int = 0,
+        ):
+            prompt_index_dict = (
+                {"prompt_index": prompt_index} if prompt_index is not None else {}
+            )
+            for ds_name, (val_h, val_gt, _) in val_output.items():
+                meta = {"dataset": ds_name, "layer": layer}
+                val_credences = reporter(val_h[:, [i], :, :])
+
+                for mode in ("none", "partial", "full"):
+                    row_bufs["eval"].append(
+                        {
+                            **meta,
+                            "ensembling": mode,
+                            **evaluate_preds(val_gt, val_credences, mode).to_dict(),
+                            **prompt_index_dict,
+                        }
+                    )
+
+                    lr_dir = experiment_dir / "lr_models"
+                    if not self.skip_supervised and lr_dir.exists():
+                        with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
+                            lr_models = torch.load(f, map_location=device)
+                        if not isinstance(
+                            lr_models, list
+                        ):  # backward compatibility
+                            lr_models = [lr_models]
+
+                        for i, model in enumerate(lr_models):
+                            model.eval()
+                            row_bufs["lr_eval"].append(
+                                {
+                                    "ensembling": mode,
+                                    "inlp_iter": i,
+                                    **meta,
+                                    **evaluate_preds(
+                                        val_gt, model(val_h), mode
+                                    ).to_dict(),
+                                }
+                            )
+
+        if isinstance(reporter, MultiReporter):
+            for i, res in enumerate(reporter.reporter_w_infos):
+                eval_all(res.model, res.prompt_index, i)
+            eval_all(reporter, "multi")
+        else:
+            eval_all(reporter)
 
         return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
diff --git a/elk/run.py b/elk/run.py
index fb8903cc..a17b01cb 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -30,6 +30,8 @@
     select_usable_devices,
 )
 
+PreparedData = dict[str, tuple[Tensor, Tensor, Tensor | None]]
+
 
 @dataclass
 class Run(ABC, Serializable):
@@ -46,11 +48,14 @@ class Run(ABC, Serializable):
     prompt_indices: tuple[int, ...] = ()
     """The indices of the prompt templates to use. If empty, all prompts are used."""
 
+    probe_per_prompt: bool = False
+    """If true, a probe is trained per prompt template. Otherwise, a single probe is
+    trained for all prompt templates."""
+
     concatenated_layer_offset: int = 0
     debug: bool = False
     min_gpu_mem: int | None = None  # in bytes
     num_gpus: int = -1
-    out_dir: Path | None = None
     disable_cache: bool = field(default=False, to_dict=False)
 
     def execute(
@@ -99,13 +104,16 @@ def execute(
         devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
         num_devices = len(devices)
         func: Callable[[int], dict[str, pd.DataFrame]] = partial(
-            self.apply_to_layer, devices=devices, world_size=num_devices
+            self.apply_to_layer,
+            devices=devices,
+            world_size=num_devices,
+            probe_per_prompt=self.probe_per_prompt,
         )
         self.apply_to_layers(func=func, num_devices=num_devices)
 
     @abstractmethod
     def apply_to_layer(
-        self, layer: int, devices: list[str], world_size: int
+        self, layer: int, devices: list[str], world_size: int, probe_per_prompt: bool
     ) -> dict[str, pd.DataFrame]:
         """Train or eval a reporter on a single layer."""
 
@@ -125,7 +133,7 @@ def get_device(self, devices, world_size: int) -> str:
 
     def prepare_data(
         self, device: str, layer: int, split_type: Literal["train", "val"]
-    ) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]:
+    ) -> PreparedData:
         """Prepare data for the specified layer and split type."""
         out = {}
 
@@ -136,7 +144,7 @@ def prepare_data(
             labels = assert_type(Tensor, split["label"])
             hiddens = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))
             if self.prompt_indices:
-                hiddens = hiddens[:, self.prompt_indices]
+                hiddens = hiddens[:, self.prompt_indices, ...]
 
             with split.formatted_as("torch", device=device):
                 has_preds = "model_logits" in split.features
@@ -186,7 +194,18 @@ def apply_to_layers(
         finally:
             # Make sure the CSVs are written even if we crash or get interrupted
             for name, dfs in df_buffers.items():
-                df = pd.concat(dfs).sort_values(by=["layer", "ensembling"])
-                df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
+                sortby = ["layer", "ensembling"]
+                if "prompt_index" in dfs[0].columns:
+                    sortby.append("prompt_index")
+                df = pd.concat(dfs).sort_values(by=sortby)
+
+                if "prompt_index" in df.columns:
+                    cols = list(df.columns)
+                    cols.insert(2, cols.pop(cols.index("prompt_index")))
+                    df = df.reindex(columns=cols)
+
+                # Save the CSV
+                out_path = self.out_dir / f"{name}.csv"
+                df.round(4).to_csv(out_path, index=False)
             if self.debug:
                 save_debug_log(self.datasets, self.out_dir)
diff --git a/elk/training/multi_reporter.py b/elk/training/multi_reporter.py
new file mode 100644
index 00000000..602b23e9
--- /dev/null
+++ b/elk/training/multi_reporter.py
@@ -0,0 +1,56 @@
+from dataclasses import dataclass
+from pathlib import Path
+
+import torch as t
+
+from elk.training import CcsReporter
+from elk.training.common import Reporter
+
+SingleReporter = CcsReporter | Reporter
+
+
+@dataclass
+class ReporterWithInfo:  # I don't love this name but I have no choice because
+    # of the other Reporter
+    model: SingleReporter
+    train_loss: float | None = None
+    prompt_index: int | None = None
+
+
+class MultiReporter:
+    def __init__(self, reporter: list[ReporterWithInfo]):
+        assert len(reporter) > 0, "Must have at least one reporter"
+        self.reporter_w_infos: list[ReporterWithInfo] = reporter
+        self.models = [r.model for r in reporter]
+        train_losses = (
+            [r.train_loss for r in reporter]
+            if reporter[0].train_loss is not None
+            else None
+        )
+
+        self.train_loss = (
+            sum(train_losses) / len(train_losses)  # type: ignore
+            if train_losses is not None
+            else None
+        )
+
+    def __call__(self, h):
+        num_variants = h.shape[1]
+        assert len(self.models) == num_variants
+        credences = []
+        for i, reporter in enumerate(self.models):
+            credences.append(reporter(h[:, [i], :, :]))
+        return t.stack(credences, dim=0).mean(dim=0)
+
+    @staticmethod
+    def load(path: Path, layer: int, device: str):
+        prompt_folders = [p for p in path.iterdir() if p.is_dir()]
+        reporters = [
+            (
+                t.load(folder / "reporters" / f"layer_{layer}.pt", map_location=device),
+                int(folder.name.split("_")[-1]),  # prompt index
+            )
+            for folder in prompt_folders
+        ]
+        # we don't care about the train losses for evaluating
+        return MultiReporter([ReporterWithInfo(r, None, pi) for r, pi in reporters])
diff --git a/elk/training/train.py b/elk/training/train.py
index 8392f2d9..b0af8b39 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -1,23 +1,113 @@
 """Main training loop."""
 
 from collections import defaultdict
-from dataclasses import dataclass
-from pathlib import Path
+from dataclasses import dataclass, replace
 from typing import Literal
 
 import pandas as pd
 import torch
 from einops import rearrange, repeat
 from simple_parsing import subgroups
-from simple_parsing.helpers.serialization import save
 
+from ..evaluation import Eval
 from ..metrics import evaluate_preds, to_one_hot
-from ..run import Run
+from ..run import PreparedData, Run
 from ..training.supervised import train_supervised
-from ..utils.typing import assert_type
+from . import Classifier
 from .ccs_reporter import CcsConfig, CcsReporter
 from .common import FitterConfig
 from .eigen_reporter import EigenFitter, EigenFitterConfig
+from .multi_reporter import MultiReporter, ReporterWithInfo, SingleReporter
+
+
+def evaluate_and_save(
+    train_loss: float | None,
+    reporter: SingleReporter | MultiReporter,
+    train_dict: PreparedData,
+    val_dict: PreparedData,
+    lr_models: list[Classifier],
+    layer: int,
+):
+    row_bufs = defaultdict(list)
+    for ds_name in val_dict:
+        val_h, val_gt, val_lm_preds = val_dict[ds_name]
+        train_h, train_gt, train_lm_preds = train_dict[ds_name]
+        meta = {"dataset": ds_name, "layer": layer}
+
+        def eval_all(
+            reporter: SingleReporter | MultiReporter,
+            prompt_index: int | Literal["multi"] | None = None,
+            i: int = 0,
+        ):
+            if isinstance(prompt_index, int):
+                val_credences = reporter(val_h[:, [i], :, :])
+                train_credences = reporter(train_h[:, [i], :, :])
+            else:
+                val_credences = reporter(val_h)
+                train_credences = reporter(train_h)
+            prompt_index_dict = (
+                {"prompt_index": prompt_index} if prompt_index is not None else {}
+            )
+            for mode in ("none", "partial", "full"):
+                row_bufs["eval"].append(
+                    {
+                        **meta,
+                        "ensembling": mode,
+                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
+                        "train_loss": train_loss,
+                        **prompt_index_dict,
+                    }
+                )
+
+                row_bufs["train_eval"].append(
+                    {
+                        **meta,
+                        "ensembling": mode,
+                        **evaluate_preds(train_gt, train_credences, mode).to_dict(),
+                        "train_loss": train_loss,
+                        **prompt_index_dict,
+                    }
+                )
+
+                if val_lm_preds is not None:
+                    row_bufs["lm_eval"].append(
+                        {
+                            **meta,
+                            "ensembling": mode,
+                            **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
+                            **prompt_index_dict,
+                        }
+                    )
+
+                if train_lm_preds is not None:
+                    row_bufs["train_lm_eval"].append(
+                        {
+                            **meta,
+                            "ensembling": mode,
+                            **evaluate_preds(train_gt, train_lm_preds, mode).to_dict(),
+                            **prompt_index_dict,
+                        }
+                    )
+
+            for lr_model_num, model in enumerate(lr_models):
+                row_bufs["lr_eval"].append(
+                    {
+                        **meta,
+                        "ensembling": mode,
+                        "inlp_iter": lr_model_num,
+                        **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
+                        **prompt_index_dict,
+                    }
+                )
+
+        if isinstance(reporter, MultiReporter):
+            for i, reporter_result in enumerate(reporter.reporter_w_infos):
+                eval_all(reporter_result.model, reporter_result.prompt_index, i)
+            eval_all(reporter, prompt_index="multi")
+        else:
+            eval_all(reporter, prompt_index=None)
+
+    return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
 
 
 @dataclass
@@ -34,35 +124,31 @@ class Elicit(Run):
     cross-validation. Defaults to "single", which means to train a single classifier
     on the training data. "cv" means to use cross-validation."""
 
-    def create_models_dir(self, out_dir: Path):
-        lr_dir = None
-        lr_dir = out_dir / "lr_models"
-        reporter_dir = out_dir / "reporters"
-
-        lr_dir.mkdir(parents=True, exist_ok=True)
-        reporter_dir.mkdir(parents=True, exist_ok=True)
-
-        # Save the reporter config separately in the reporter directory
-        # for convenient loading of reporters later.
-        save(self.net, reporter_dir / "cfg.yaml", save_dc_types=True)
-
-        return reporter_dir, lr_dir
-
-    def apply_to_layer(
-        self,
-        layer: int,
-        devices: list[str],
-        world_size: int,
-    ) -> dict[str, pd.DataFrame]:
-        """Train a single reporter on a single layer."""
-
-        self.make_reproducible(seed=self.net.seed + layer)
-        device = self.get_device(devices, world_size)
-
-        train_dict = self.prepare_data(device, layer, "train")
-        val_dict = self.prepare_data(device, layer, "val")
-
-        (first_train_h, train_gt, _), *rest = train_dict.values()
+    def make_eval(self, model, eval_dataset):
+        assert self.out_dir is not None
+        return Eval(
+            data=replace(
+                self.data,
+                model=model,
+                datasets=(eval_dataset,),
+            ),
+            source=self.out_dir,
+            out_dir=self.out_dir / "transfer" / eval_dataset,
+            num_gpus=self.num_gpus,
+            min_gpu_mem=self.min_gpu_mem,
+            skip_supervised=self.supervised == "none",
+            prompt_indices=self.prompt_indices,
+            concatenated_layer_offset=self.concatenated_layer_offset,
+            # datasets isn't needed because it's immediately overwritten
+            debug=self.debug,
+            disable_cache=self.disable_cache,
+        )
+
+    # Create a separate function to handle the reporter training.
+    def train_and_save_reporter(
+        self, device, layer, out_dir, train_dict, prompt_index=None
+    ) -> ReporterWithInfo:
+        (first_train_h, train_gt, _), *rest = train_dict.values()  # TODO can remove?
         (_, v, k, d) = first_train_h.shape
         if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same hidden state size")
@@ -74,16 +160,12 @@ def apply_to_layer(
         if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same number of classes")
 
-        reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
         train_loss = None
-
        if isinstance(self.net, CcsConfig):
             assert len(train_dict) == 1, "CCS only supports single-task training"
-
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
-
             (_, v, k, _) = first_train_h.shape
             reporter.platt_scale(
                 to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
                 rearrange(first_train_h, "n v k d -> (n v k) d"),
@@ -115,73 +197,93 @@ def apply_to_layer(
             raise ValueError(f"Unknown reporter config type: {type(self.net)}")
 
         # Save reporter checkpoint to disk
-        torch.save(reporter, reporter_dir / f"layer_{layer}.pt")
+        # TODO have to change this
+        out_dir.mkdir(parents=True, exist_ok=True)
+        torch.save(reporter, out_dir / f"layer_{layer}.pt")
+
+        return ReporterWithInfo(reporter, train_loss, prompt_index)
 
-        # Fit supervised logistic regression model
+    def train_lr_model(self, train_dict, device, layer, out_dir) -> list[Classifier]:
         if self.supervised != "none":
             lr_models = train_supervised(
                 train_dict,
                 device=device,
                 mode=self.supervised,
             )
-            with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
+            # make dir if not exists
+            out_dir.mkdir(parents=True, exist_ok=True)
+            with open(out_dir / f"layer_{layer}.pt", "wb") as file:
                 torch.save(lr_models, file)
         else:
             lr_models = []
 
-        row_bufs = defaultdict(list)
-        for ds_name in val_dict:
-            val_h, val_gt, val_lm_preds = val_dict[ds_name]
-            train_h, train_gt, train_lm_preds = train_dict[ds_name]
-            meta = {"dataset": ds_name, "layer": layer}
+        return lr_models
 
-            val_credences = reporter(val_h)
-            train_credences = reporter(train_h)
-            for mode in ("none", "partial", "full"):
-                row_bufs["eval"].append(
-                    {
-                        **meta,
-                        "ensembling": mode,
-                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
-                        "train_loss": train_loss,
-                    }
-                )
+    def apply_to_layer(
+        self,
+        layer: int,
+        devices: list[str],
+        world_size: int,
+        probe_per_prompt: bool,
+    ) -> dict[str, pd.DataFrame]:
+        """Train a single reporter on a single layer."""
+        assert self.out_dir is not None  # TODO this is really annoying, why can it be
+        # None?
 
-                row_bufs["train_eval"].append(
-                    {
-                        **meta,
-                        "ensembling": mode,
-                        **evaluate_preds(train_gt, train_credences, mode).to_dict(),
-                        "train_loss": train_loss,
-                    }
-                )
+        self.make_reproducible(seed=self.net.seed + layer)
+        device = self.get_device(devices, world_size)
 
-                if val_lm_preds is not None:
-                    row_bufs["lm_eval"].append(
-                        {
-                            **meta,
-                            "ensembling": mode,
-                            **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
-                        }
-                    )
+        train_dict = self.prepare_data(device, layer, "train")
+        val_dict = self.prepare_data(device, layer, "val")
 
-                if train_lm_preds is not None:
-                    row_bufs["train_lm_eval"].append(
-                        {
-                            **meta,
-                            "ensembling": mode,
-                            **evaluate_preds(train_gt, train_lm_preds, mode).to_dict(),
-                        }
-                    )
+        (first_train_h, train_gt, _), *rest = train_dict.values()
+        (_, v, k, d) = first_train_h.shape
 
-            for i, model in enumerate(lr_models):
-                row_bufs["lr_eval"].append(
-                    {
-                        **meta,
-                        "ensembling": mode,
-                        "inlp_iter": i,
-                        **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
-                    }
+        if probe_per_prompt:
+            # self.prompt_indices being () actually means "all prompts"
+            prompt_indices = self.prompt_indices if self.prompt_indices else range(v)
+            prompt_train_dicts = [
+                {
+                    ds_name: (
+                        train_h[:, [i], ...],
+                        train_gt,
+                        lm_preds[:, [i], ...] if lm_preds is not None else None,
                )
+                }
+                for ds_name, (train_h, _, lm_preds) in train_dict.items()
+                for i, _ in enumerate(prompt_indices)
+            ]
+
+            results = []
+
+            for prompt_index, prompt_train_dict in zip(
+                prompt_indices, prompt_train_dicts
+            ):
+                assert prompt_index < 100  # format i as a 2 digit string
+                str_i = str(prompt_index).zfill(2)
+                base = self.out_dir / "reporters" / f"prompt_{str_i}"
+                reporters_path = base / "reporters"
+
+                reporter_train_result = self.train_and_save_reporter(
+                    device, layer, reporters_path, prompt_train_dict, prompt_index
+                )
+                results.append(reporter_train_result)
+
+            # it is called maybe_multi_reporter because it might be a single reporter
+            maybe_multi_reporter = MultiReporter(results)
+            train_loss = maybe_multi_reporter.train_loss
+        else:
+            reporter_train_result = self.train_and_save_reporter(
+                device, layer, self.out_dir / "reporters", train_dict
            )
+
+            maybe_multi_reporter = reporter_train_result.model
+            train_loss = reporter_train_result.train_loss
+
+        lr_models = self.train_lr_model(
+            train_dict, device, layer, self.out_dir / "lr_models"
+        )
 
-        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
+        return evaluate_and_save(
+            train_loss, maybe_multi_reporter, train_dict, val_dict, lr_models, layer
        )