Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ nohup.out
.idea
*.pkl

# Don't track any working/testing notebooks
*.ipynb

# Contributors may have their own custom VS code debug configurations
.vscode/launch.json

# scripts for experiments in progress
my_*.sh

Expand Down
35 changes: 20 additions & 15 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from simple_parsing import Serializable, field
from torch import Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from ..promptsource import DatasetTemplates
Expand All @@ -32,6 +32,7 @@
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_config,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
Expand Down Expand Up @@ -117,8 +118,8 @@ def extract_hiddens(
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

model = instantiate_model(
cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
).to(device)
cfg.model, device, torch_dtype="auto" if device != "cpu" else torch.float32
)
tokenizer = instantiate_tokenizer(
cfg.model, truncation_side="left", verbose=rank == 0
)
Expand Down Expand Up @@ -239,17 +240,21 @@ def extract_hiddens(
outputs.get("decoder_hidden_states") or outputs["hidden_states"]
)
# Throw out layers we don't care about
hiddens = [hiddens[i] for i in layer_indices]

# Current shape of each element: (batch_size, seq_len, hidden_size)
if cfg.token_loc == "first":
hiddens = [h[..., 0, :] for h in hiddens]
elif cfg.token_loc == "last":
hiddens = [h[..., -1, :] for h in hiddens]
elif cfg.token_loc == "mean":
hiddens = [h.mean(dim=-2) for h in hiddens]
else:
raise ValueError(f"Invalid token_loc: {cfg.token_loc}")
if not cfg.model.startswith("BlinkDL/rwkv"):
hiddens = [hiddens[i] for i in layer_indices]

if has_per_token_states := len(hiddens[-1].shape) > 1:
# Current shape of each element: (batch_size, seq_len, hidden_size)
if cfg.token_loc == "first":
hiddens = [h[..., 0, :] for h in hiddens]
elif cfg.token_loc == "last":
hiddens = [
h[..., -1, :] if len(h.shape) >= 2 else h for h in hiddens
]
elif cfg.token_loc == "mean":
hiddens = [h.mean(dim=-2) for h in hiddens]
else:
raise ValueError(f"Invalid token_loc: {cfg.token_loc}")

for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)
Expand Down Expand Up @@ -306,7 +311,7 @@ def get_splits() -> SplitDict:
dataset_name=available_splits.dataset_name,
)

model_cfg = AutoConfig.from_pretrained(cfg.model)
model_cfg = instantiate_config(cfg.model)

ds_name, config_name = extract_dataset_name_and_config(
dataset_config_str=cfg.prompts.datasets[0]
Expand Down
Loading