README.md (3 additions, 3 deletions)
@@ -57,11 +57,11 @@ together. You can also add a `--visualize` flag to visualize the results of the
elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
```

-If you just do `elk plot`, it will plot the results from the most recent sweep.
-If you want to plot a specific sweep, you can do so with:
+If you just do `elk plot`, it will plot the AUROC results from the most recent sweep.
+If you want to plot a specific sweep with a specific metric, you can do so with:

```bash
-elk plot {sweep_name}
+elk plot {sweep_name} --metric_type acc_estimate
```

## Caching
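For reference, a minimal usage sketch of the new option, assuming the CLI flag generated from the `metric_type` field is spelled `--metric_type`:

```bash
# Plot the most recent sweep with the default metric (auroc_estimate)
elk plot

# Plot a specific sweep using accuracy instead
elk plot {sweep_name} --metric_type acc_estimate
```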
elk/plotting/command.py (4 additions, 1 deletion)
@@ -22,6 +22,9 @@ class Plot:
overwrite: bool = False
"""Whether to overwrite existing plots."""

+metric_type: str = "auroc_estimate"
+"""Name of the metric to plot."""

def execute(self):
root_dir = sweeps_dir()

@@ -47,4 +50,4 @@ def execute(self):
if self.overwrite:
shutil.rmtree(sweep_path / "viz")

-visualize_sweep(sweep_path)
+visualize_sweep(sweep_path, self.metric_type)
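The same path can be exercised programmatically; a minimal sketch using the two-argument `visualize_sweep` signature from this diff (the sweep directory below is hypothetical):

```python
from pathlib import Path

from elk.plotting.visualize import visualize_sweep

# Hypothetical sweep directory; the second argument names the metric
# column to plot, mirroring what Plot.execute() passes through.
sweep_path = Path.home() / "elk-reporters" / "sweeps" / "my_sweep"
visualize_sweep(sweep_path, "acc_estimate")
```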
elk/plotting/visualize.py (27 additions, 27 deletions)
@@ -47,7 +47,7 @@ def render(
shared_yaxes=True,
vertical_spacing=0.1,
x_title="Layer",
-y_title="AUROC",
+y_title=sweep.metric_type,
)
color_map = dict(zip(ensembles, qualitative.Plotly))

@@ -56,7 +56,7 @@
if with_transfer: # TODO write tests
ensemble_data = ensemble_data.groupby(
["eval_dataset", "layer", "ensembling"], as_index=False
-).agg({"auroc_estimate": "mean"})
+).agg({sweep.metric_type: "mean"})
else:
ensemble_data = ensemble_data[
ensemble_data["eval_dataset"] == ensemble_data["train_dataset"]
@@ -75,7 +75,7 @@
fig.add_trace(
go.Scatter(
x=dataset_data["layer"],
-y=dataset_data["auroc_estimate"],
+y=dataset_data[sweep.metric_type],
mode="lines",
name=ensemble,
showlegend=False
@@ -95,7 +95,7 @@
legend=dict(
title="Ensembling",
),
title=f"AUROC Trend: {self.model_name}",
title=f"{sweep.metric_type} Trend: {self.model_name}",
)
if write:
fig.write_image(
@@ -114,7 +114,7 @@ class TransferEvalHeatmap:
"""Class for generating heatmaps for transfer evaluation results."""

layer: int
-score_type: str = "auroc_estimate"
+metric_type: str = "auroc_estimate"
ensembling: str = "full"

def render(self, df: pd.DataFrame) -> go.Figure:
@@ -129,27 +129,28 @@ def render(self, df: pd.DataFrame) -> go.Figure:
model_name = df["eval_dataset"].iloc[0] # infer model name
# TODO: validate
pivot = pd.pivot_table(
-df, values=self.score_type, index="eval_dataset", columns="train_dataset"
+df, values=self.metric_type, index="eval_dataset", columns="train_dataset"
)

fig = px.imshow(pivot, color_continuous_scale="Viridis", text_auto=True)

fig.update_layout(
xaxis_title="Train Dataset",
yaxis_title="Transfer Dataset",
title=f"AUROC Score Heatmap: {model_name} | Layer {self.layer}",
title=f"{self.metric_type} Score Heatmap: {model_name} \
| Layer {self.layer}",
)

return fig


@dataclass
class TransferEvalTrend:
"""Class for generating line plots for the trend of AUROC scores in transfer
"""Class for generating line plots for the trend of metric scores in transfer
evaluation."""

dataset_names: list[str] | None
-score_type: str = "auroc_estimate"
+metric_type: str = "auroc_estimate"

def render(self, df: pd.DataFrame) -> go.Figure:
"""Render the trend plot visualization.
@@ -164,14 +165,14 @@ def render(self, df: pd.DataFrame) -> go.Figure:
if self.dataset_names is not None:
df = self._filter_transfer_datasets(df, self.dataset_names)
pivot = pd.pivot_table(
-df, values=self.score_type, index="layer", columns="eval_dataset"
+df, values=self.metric_type, index="layer", columns="eval_dataset"
)

fig = px.line(pivot, color_discrete_sequence=px.colors.qualitative.Plotly)
fig.update_layout(
xaxis_title="Layer",
yaxis_title="AUROC Score",
title=f"AUROC Score Trend: {model_name}",
yaxis_title=f"{self.metric_type} Score",
title=f"{self.metric_type} Score Trend: {model_name}",
)

avg = pivot.mean(axis=1)
@@ -244,17 +245,16 @@ def render_and_save(
self,
sweep: "SweepVisualization",
dataset_names: list[str] | None = None,
score_type="auroc_estimate",
ensembling="full",
) -> None:
"""Render and save the visualization for the model.

Args:
sweep: The SweepVisualization instance.
dataset_names: List of dataset names to include in the visualization.
-score_type: The type of score to display.
ensembling: The ensembling option to consider.
"""
+metric_type = sweep.metric_type
df = self.df
model_name = self.model_name
layer_min, layer_max = df["layer"].min(), df["layer"].max()
Expand All @@ -264,10 +264,10 @@ def render_and_save(
for layer in range(layer_min, layer_max + 1):
filtered = df[(df["layer"] == layer) & (df["ensembling"] == ensembling)]
fig = TransferEvalHeatmap(
-layer, score_type=score_type, ensembling=ensembling
+layer, metric_type=metric_type, ensembling=ensembling
).render(filtered)
fig.write_image(file=model_path / f"{layer}.png")
-fig = TransferEvalTrend(dataset_names).render(df)
+fig = TransferEvalTrend(dataset_names, metric_type=metric_type).render(df)
fig.write_image(file=model_path / "transfer_eval_trend.png")

@staticmethod
@@ -288,6 +288,7 @@ class SweepVisualization:
path: Path
datasets: list[str]
models: dict[str, ModelVisualization]
+metric_type: str
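+"""Name of the metric column to plot."""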

def model_names(self) -> list[str]:
"""Get the names of all models in the sweep.
@@ -323,7 +324,7 @@ def _get_model_paths(sweep_path: Path) -> list[Path]:
return folders

@classmethod
-def collect(cls, sweep_path: Path) -> "SweepVisualization":
+def collect(cls, sweep_path: Path, metric_type: str) -> "SweepVisualization":
"""Collect the evaluation data for a sweep.

Args:
@@ -348,7 +349,9 @@ def collect(cls, sweep_path: Path) -> "SweepVisualization":
}
df = pd.concat([model.df for model in models.values()], ignore_index=True)
datasets = list(df["eval_dataset"].unique())
-return cls(sweep_name, df, sweep_viz_path, datasets, models)
+return cls(
+sweep_name, df, sweep_viz_path, datasets, models, metric_type=metric_type
+)

def render_and_save(self):
"""Render and save all visualizations for the sweep."""
@@ -368,14 +371,11 @@ def render_multiplots(self, write=False):
for model in self.models
]

-def render_table(
-self, score_type="auroc_estimate", display=True, write=False
-) -> pd.DataFrame:
+def render_table(self, display=True, write=False) -> pd.DataFrame:
"""Render and optionally write the score table.

Args:
layer: The layer number (from last layer) to include in the score table.
-score_type: The type of score to include in the table.
display: Flag indicating whether to display the table to stdout.
write: Flag indicating whether to write the table to a file.

@@ -387,15 +387,15 @@ def render_table(
# For each model, we use the layer whose mean AUROC is the highest
best_layers, model_dfs = [], []
for _, model_df in df.groupby("model_name"):
-best_layer = model_df.groupby("layer").auroc_estimate.mean().argmax()
+best_layer = model_df.groupby("layer")[self.metric_type].mean().argmax()

best_layers.append(best_layer)
model_dfs.append(model_df[model_df["layer"] == best_layer])

pivot_table = pd.concat(model_dfs).pivot_table(
index="eval_dataset",
columns="model_name",
-values=score_type,
+values=self.metric_type,
margins=True,
margins_name="Mean",
)
@@ -416,14 +416,14 @@
console.print(table)

if write:
pivot_table.to_csv(f"score_table_{score_type}.csv")
pivot_table.to_csv(f"score_table_{self.metric_type}.csv")
return pivot_table


-def visualize_sweep(sweep_path: Path):
+def visualize_sweep(sweep_path: Path, metric_type: str):
"""Visualize a sweep by generating and saving the visualizations.

Args:
sweep_path: The path to the sweep data directory.
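+metric_type: Name of the metric column to plot, e.g. "auroc_estimate".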
"""
-SweepVisualization.collect(sweep_path).render_and_save()
+SweepVisualization.collect(sweep_path, metric_type).render_and_save()
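To make the best-layer selection in `render_table` concrete, here is a self-contained sketch with toy data (hypothetical numbers; it uses `idxmax()`, which returns the layer label itself, rather than the positional `argmax()` used above):

```python
import pandas as pd

# Toy eval results: one model, two layers, two eval datasets.
df = pd.DataFrame({
    "model_name": ["gpt2"] * 4,
    "eval_dataset": ["imdb", "amazon_polarity"] * 2,
    "layer": [5, 5, 6, 6],
    "acc_estimate": [0.81, 0.79, 0.88, 0.86],
})

metric_type = "acc_estimate"
model_dfs = []
for _, model_df in df.groupby("model_name"):
    # Keep the layer whose mean metric is highest: layer 6 here,
    # since mean(0.88, 0.86) > mean(0.81, 0.79).
    best_layer = model_df.groupby("layer")[metric_type].mean().idxmax()
    model_dfs.append(model_df[model_df["layer"] == best_layer])

# Score table: one column per model, one row per eval dataset,
# plus a "Mean" margin, as in render_table above.
pivot_table = pd.concat(model_dfs).pivot_table(
    index="eval_dataset",
    columns="model_name",
    values=metric_type,
    margins=True,
    margins_name="Mean",
)
print(pivot_table)
```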
elk/training/sweep.py (4 additions, 1 deletion)
@@ -50,6 +50,9 @@ class Sweep:
visualize: bool = False
"""Whether to generate visualizations of the results of the sweep."""

+metric_type: str = "auroc_estimate"
+"""Name of the metric to plot."""

name: str | None = None

# A bit of a hack to add all the command line arguments from Elicit
@@ -176,4 +179,4 @@ def execute(self):
eval.execute(highlight_color="green")

if self.visualize:
-visualize_sweep(sweep_dir)
+visualize_sweep(sweep_dir, self.metric_type)
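End to end, a sweep can then render its own plots with the chosen metric; a sketch, assuming the flag names mirror the dataclass fields shown in this diff:

```bash
# Run a small sweep and visualize its results using accuracy
elk sweep --models gpt2 --datasets imdb --visualize --metric_type acc_estimate
```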