
helper script for plotting #463

Open
josephdviviano wants to merge 1 commit into master from multinode_analysis

Conversation

@josephdviviano
Collaborator

- I've read the .github/CONTRIBUTING.md file
- My code follows the typing guidelines
- I've added appropriate tests
- I've run pre-commit hooks locally

Description

Plotting tool for the multinode experiments.

Automatically scrapes wandb for results and plots them. To be expanded.

@josephdviviano
Collaborator Author

For our notes - here's a summary of the work done / next steps:

# TorchGFN Multinode Scaling Analysis - Project State

## Overview
A comprehensive analysis framework for evaluating multinode scaling experiments in TorchGFN using Weights & Biases data. The tool visualizes how different community sizes and strategies affect mode discovery performance across hypergrid environments.

## Current Capabilities

### 📊 Data Pipeline
- **Wandb Integration**: Fetches runs from the `torchgfn/torchgfn` project with robust timeout handling (see the sketch below)
- **Hierarchical Organization**: Environment → Community → Runs structure
- **Strategy Extraction**: Automatically identifies unique strategy configurations across all runs
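
The fetch step described above might look roughly like the sketch below. It uses the public `wandb` client API (`wandb.Api` accepts a `timeout` argument, and `api.runs()` takes an `entity/project` path); the retry loop and function signature are illustrative, not the script's exact implementation:

```python
import wandb

WANDB_PROJECT = "torchgfn/torchgfn"


def fetch_wandb_runs(timeout_s: int = 60, max_retries: int = 3) -> list:
    """Fetch all runs for the project, retrying on transient API errors."""
    api = wandb.Api(timeout=timeout_s)
    for attempt in range(1, max_retries + 1):
        try:
            # api.runs() is lazy; materialize the list so network errors surface here
            return list(api.runs(WANDB_PROJECT))
        except wandb.errors.CommError as e:
            print(f"wandb fetch failed (attempt {attempt}/{max_retries}): {e}")
    return []
```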

### 🎨 Visualization System

#### Multi-Dimensional Encoding
| Dimension | Visual Encoding | Purpose |
|-----------|----------------|---------|
| Community Size | **Color** (tab20 colormap) | Compare scaling (Size 1, 2, 4, 8, 16, 32...) |
| Strategy | **Linestyle + Marker** | Distinguish experimental conditions |
| Run State | **Opacity** | Solid=finished, Faded=crashed/failed |

#### Plot Layout
- **70/30 split**: Main plot takes 70% of vertical space, legend takes 30%
- **Three-section legend**:
  1. Community Sizes (horizontal row, color-coded)
  2. Strategies (single column, linestyle + marker shown)
  3. Run States (opacity explanation)

#### Visual Tuning
- Thin lines (`linewidth=1.2`) to reduce overlap
- Large markers (`markersize=9`) for shape distinction
- ~12 markers per line for visual clarity
- Legend shows full linestyle pattern (`handlelength=4`)
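
The review comments below suggest naming these tuning values; collected as module-level constants (the names are hypothetical, the values come from the bullets above), that might look like:

```python
# Visual tuning constants (names are hypothetical; values match the bullets above)
LINE_WIDTH = 1.2               # thin lines to reduce overlap
MARKER_SIZE = 9                # large markers for shape distinction
TARGET_MARKERS_PER_LINE = 12   # markevery = max(1, len(steps) // TARGET_MARKERS_PER_LINE)
LEGEND_HANDLE_LENGTH = 4       # long legend handles show the full linestyle pattern
```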

### 🏷️ Strategy Mapping System

```python
STRATEGY_MAPPING = {
    "average_every=100_...": "Baseline",
    "average_every=100_..._use_selective_averaging=True": "Selective Averaging",
    "average_every=100_..._use_random_strategies=True": "Random Strategies",
    # Add more mappings as needed
}
```

- **Consistent encoding**: Same linestyle/marker for each strategy across all environment plots
- **Quick discovery mode**: `--print-strategies-only` flag lists all strategies without generating plots

### 🖥️ Command Line Interface

```bash
# Full analysis with plots
python multinode_scaling_analysis.py

# Quick mode: just print strategies for mapping
python multinode_scaling_analysis.py --print-strategies-only
```
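
The `--print-strategies-only` flag is the only documented option; a minimal argparse wiring for it (a sketch; the script's actual CLI may define more) could be:

```python
import argparse


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Plot multinode scaling results scraped from wandb."
    )
    parser.add_argument(
        "--print-strategies-only",
        action="store_true",
        help="List unique strategy IDs (for STRATEGY_MAPPING) without plotting.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.print_strategies_only:
        print_strategies()  # hypothetical helper
    else:
        run_full_analysis()  # hypothetical helper
```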

## Current Strategy Mappings

| Long-form ID | Shorthand |
|--------------|-----------|
| `average_every=100_..._use_selective_averaging=False` | Baseline |
| `average_every=100_..._use_selective_averaging=True` | Selective Averaging |
| `average_every=100_..._use_random_strategies=True` | Random Strategies |
| `average_every=100_..._use_random_strategies=True_use_selective_averaging=True` | Selective Averaging & Random Strategies |
| `average_every=16384000_...` | Baseline (16M steps Averaging?) |
| `average_every=4294967296_...` | Baseline (4B steps Averaging?) |

## File Structure

```
tutorials/notebooks/
├── multinode_scaling_analysis.py    # Main analysis script (~1010 lines)
└── multinode_scaling_analysis.ipynb # Jupyter notebook companion
```

## Possible Next Steps for Analysis

### 📈 Quantitative Analysis

1. **Convergence Speed Metrics**
   - Time/iterations to reach X% of max modes
   - Compare convergence rates across strategies and sizes
   - Statistical significance testing (t-tests, ANOVA)
2. **Scaling Efficiency** (see the sketch after this list)
   - Plot modes-per-agent vs community size
   - Identify the diminishing-returns threshold
   - Calculate parallelization efficiency
3. **Strategy Effectiveness Ranking**
   - Aggregate performance across all environments
   - Rank strategies by average mode discovery rate
   - Identify the best strategy per environment type
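
One way to make item 2's "parallelization efficiency" concrete: treat size-1 communities as the baseline and ask what fraction of linear scaling each size achieves. This is a sketch under the assumption that performance is summarized as average max modes per community size (the input name is hypothetical, and mode counts saturate, so this is a crude proxy):

```python
def parallelization_efficiency(modes_by_size: dict[int, float]) -> dict[int, float]:
    """Efficiency of size-N communities relative to linear scaling of size 1.

    modes_by_size maps community size -> average max modes found (hypothetical input).
    1.0 = perfect linear scaling; values below 1.0 indicate diminishing returns.
    """
    baseline = modes_by_size[1]
    return {size: modes / (size * baseline) for size, modes in modes_by_size.items()}


# Example: {1: 10.0, 2: 18.0, 4: 30.0} -> {1: 1.0, 2: 0.9, 4: 0.75}
```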

### 🔬 Deeper Investigations

1. **Environment Difficulty Analysis**
   - Which environments are hardest (lowest max modes)?
   - Does strategy effectiveness vary by environment difficulty?
   - Correlation between environment parameters and performance
2. **Failure Analysis**
   - Why do some runs crash/fail?
   - Is there a pattern (certain sizes, strategies, environments)?
   - Time-to-failure analysis
3. **Variance Analysis** (see the bootstrap sketch after this list)
   - How consistent are results within the same size/strategy?
   - Which configurations have the highest variance?
   - Bootstrap confidence intervals
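
For item 3's bootstrap confidence intervals, a standard percentile bootstrap over per-run max modes would suffice; this sketch uses only numpy (already in the stack) and is not taken from the script:

```python
import numpy as np


def bootstrap_ci(values, n_boot: int = 10_000, ci: float = 95.0, seed: int = 0):
    """Percentile-bootstrap confidence interval for the mean of `values`."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    means = np.array(
        [rng.choice(values, size=len(values), replace=True).mean() for _ in range(n_boot)]
    )
    lo, hi = np.percentile(means, [(100 - ci) / 2, 100 - (100 - ci) / 2])
    return lo, hi


# Example: CI over max modes for one (size, strategy) group
# lo, hi = bootstrap_ci([42, 45, 39, 44])
```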

### 📊 Visualization Enhancements

1. **Summary Dashboard** (see the heatmap sketch after this list)
   - Heatmap: Strategy × Size → Performance
   - Bar charts comparing strategies aggregated across environments
   - Box plots showing the distribution of max modes
2. **Normalized Comparisons**
   - Normalize by each environment's theoretical max modes
   - Compare "fraction of modes found" instead of raw counts
   - Time-normalized curves (modes per hour)
3. **Interactive Plots**
   - Plotly/Bokeh version for zooming/hovering
   - Filter by environment, strategy, or size interactively
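
The Strategy × Size heatmap in item 1 falls out of a pandas pivot plus seaborn (both already dependencies); the dataframe columns here are hypothetical, one row per finished run:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical tidy frame: one row per finished run
df = pd.DataFrame(
    {
        "strategy": ["Baseline", "Baseline", "Selective Averaging"],
        "size": [2, 4, 2],
        "max_modes": [18, 30, 22],
    }
)

# Aggregate to a Strategy x Size grid of mean max modes, then plot
pivot = df.pivot_table(index="strategy", columns="size", values="max_modes", aggfunc="mean")
sns.heatmap(pivot, annot=True, fmt=".1f", cmap="viridis")
plt.title("Strategy × Size → avg max modes")
plt.show()
```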

### 🔧 Tool Improvements

1. **Caching Layer** (see the sketch after this list)
   - Cache wandb data locally to speed up reruns
   - Incremental updates for new runs only
2. **Export Capabilities**
   - Save plots as publication-quality PDFs
   - Export summary statistics to CSV
   - Generate LaTeX tables automatically
3. **Configuration File**
   - Move STRATEGY_MAPPING to an external YAML/JSON file
   - Allow environment filtering via config
   - Customizable color/marker schemes
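
A simple version of the caching layer in item 1: serialize the already-extracted, picklable run data (not live wandb Run objects, which generally don't pickle) and reuse it on the next invocation. The path and function names are hypothetical:

```python
import pickle
from pathlib import Path

CACHE_PATH = Path("wandb_cache.pkl")  # hypothetical location


def load_or_fetch(fetch_fn):
    """Return cached run data if present; otherwise fetch, cache, and return it.

    fetch_fn must return plain picklable data (e.g., dicts of history arrays).
    """
    if CACHE_PATH.exists():
        with CACHE_PATH.open("rb") as f:
            return pickle.load(f)
    data = fetch_fn()
    with CACHE_PATH.open("wb") as f:
        pickle.dump(data, f)
    return data
```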

### 📝 Documentation & Reporting

1. **Automated Report Generation**
   - Markdown/HTML report with key findings
   - Include best/worst performing configurations
   - Trend analysis summary
2. **Experiment Recommendations**
   - Based on current results, suggest next experiments
   - Identify under-explored regions of the parameter space

## Technical Stack

- **Language**: Python 3.12
- **Data**: pandas, numpy
- **Visualization**: matplotlib, seaborn
- **Experiment Tracking**: Weights & Biases API
- **CLI**: argparse

## Quick Reference

```python
# Key constants at top of script
STRATEGY_PARAMS = ['average_every', 'replacement_ratio', ...]
STRATEGY_MAPPING = {...}  # Long-form → shorthand

# Key functions
fetch_wandb_runs()                        # Get data from wandb
analyze_communities_within_environments() # Main analysis loop
plot_communities_in_environment()         # Generate plots
format_strategy_for_legend()              # Uses STRATEGY_MAPPING
```

@codecov

codecov bot commented Jan 14, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.38%. Comparing base (a47bf73) to head (5b13588).
⚠️ Report is 35 commits behind head on master.

Additional details and impacted files
```
@@             Coverage Diff             @@
##           master     #463       +/-   ##
===========================================
+ Coverage    0.55%   74.38%   +73.83%     
===========================================
  Files          48       47        -1     
  Lines        6845     6891       +46     
  Branches      802      825       +23     
===========================================
+ Hits           38     5126     +5088     
+ Misses       6806     1454     -5352     
- Partials        1      311      +310     
```
| Flag | Coverage Δ |
|------|------------|
| unittests | ? |

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Contributor

Copilot AI left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.


```python
# Show some examples
print("\nExamples of community vs environment groupings:")
example_runs = list(run_to_community.keys())[:8]  # Show first 8
```

Copilot AI Jan 14, 2026

The magic number 8 is used to limit example runs displayed. Consider defining this as a named constant for better maintainability.
Comment on lines +796 to +797

```python
fig = plt.figure(figsize=(16, 14))
gs = fig.add_gridspec(2, 1, height_ratios=[70, 30], hspace=0.08)
```

Copilot AI Jan 14, 2026

The magic numbers for figure sizing (16, 14), height ratios (70, 30), and spacing (0.08) should be defined as named constants for better maintainability and to make it easier to adjust the layout consistently.
```python
label = f"Size {community_size}, Strategy: {strategy_short} ({run_state})"

# Determine marker frequency based on data length (show ~10-15 markers per line)
markevery = max(1, len(steps) // 12)
```

Copilot AI Jan 14, 2026

The magic number 12 is used to calculate marker frequency. Consider defining this as a named constant (e.g., TARGET_MARKERS_PER_LINE = 12) for better maintainability.
Comment on lines +227 to +228

```python
except Exception as e:
    print(f" Error checking history: {e}")
```

Copilot AI Jan 14, 2026

The broad exception catch except Exception as e on line 227 silently swallows all exceptions and only prints them. Consider catching specific exceptions (e.g., wandb.errors.CommError, requests.exceptions.RequestException) or re-raising after logging to ensure critical errors are not hidden.
Comment on lines +774 to +1090

```python
def plot_communities_in_environment(  # noqa: C901
    env_config_id,
    community_data,
    community_metadata,
    strategy_linestyle_map,
    strategy_marker_map,
):
    """Plot n_modes_found progression for communities within a single environment.

    Uses color for community size, linestyle and markers for strategy.
    Linestyle and marker mappings are passed in to ensure consistency across plots.
    """
    if not community_data:
        return

    communities_with_data = {cid: data for cid, data in community_data.items() if data}

    if len(communities_with_data) <= 1:
        print(f"Only {len(communities_with_data)} communities with data - skipping plot")
        return

    # Create figure with explicit layout: 70% plot, 30% legend space
    fig = plt.figure(figsize=(16, 14))
    gs = fig.add_gridspec(2, 1, height_ratios=[70, 30], hspace=0.08)
    ax = fig.add_subplot(gs[0])

    # Extract unique community sizes
    community_sizes = sorted(
        set(community_metadata[cid]["size"] for cid in communities_with_data.keys())
    )

    # Color mapping for community sizes
    size_colors = cm.tab20(np.linspace(0, 1, len(community_sizes)))
    size_color_map = dict(zip(community_sizes, size_colors))

    # Style mapping for run states (secondary modifier)
    state_alphas = {"finished": 1.0, "crashed": 0.7, "failed": 0.5, "running": 0.8}

    max_y = 0
    legend_handles = []
    legend_labels = []

    for community_id in sorted(communities_with_data.keys()):
        runs_data = communities_with_data[community_id]
        if not runs_data:
            continue

        community_size = community_metadata[community_id]["size"]
        community_strategy = community_metadata[community_id]["strategy"]

        # Get color for community size
        base_color = size_color_map[community_size]

        # Get linestyle and marker for strategy (from global mappings)
        base_linestyle = strategy_linestyle_map.get(community_strategy, "-")
        base_marker = strategy_marker_map.get(community_strategy, "o")

        for run_data in runs_data:
            steps = run_data["steps"]
            n_modes = run_data["n_modes_found"]
            run_state = run_data["run_state"]

            # Apply state modifier (alpha)
            alpha = state_alphas.get(run_state, 0.6)

            # Create label for legend
            strategy_short = (
                community_strategy.replace("_", ", ")
                if community_strategy
                else "unknown"
            )
            label = f"Size {community_size}, Strategy: {strategy_short} ({run_state})"

            # Determine marker frequency based on data length (show ~10-15 markers per line)
            markevery = max(1, len(steps) // 12)

            (line,) = ax.plot(
                steps,
                n_modes,
                linestyle=base_linestyle,
                color=base_color,
                alpha=alpha,
                linewidth=1.2,
                label=label,
                marker=base_marker,
                markersize=9,
                markevery=markevery,
            )

            # Only add to legend if we haven't seen this combination before
            legend_key = (
                f"size_{community_size}_strategy_{community_strategy}_{run_state}"
            )
            if legend_key not in legend_labels:
                legend_handles.append(line)
                legend_labels.append(legend_key)

            max_y = max(max_y, max(n_modes) if n_modes else 0)

    # Clean up the environment config ID for title
    env_title = env_config_id.replace("_", ", ")
    ax.set_title(
        f"Mode Discovery: {env_title}\nCommunity Size (📏) × Strategy (📊) Analysis",
        fontsize=14,
        fontweight="bold",
    )
    ax.set_xlabel("Iteration/Step", fontsize=12)
    ax.set_ylabel("Number of Modes Found", fontsize=12)
    ax.grid(True, alpha=0.3)

    # Create separate legends for colors (sizes) and linestyles (strategies)
    if legend_handles:
        # Create color legend for community sizes
        size_handles = []
        size_labels = []
        size_colors_used = set()

        # Create linestyle legend for strategies
        strategy_handles = []
        strategy_labels = []
        strategy_linestyles_used = set()

        # First pass: collect all unique sizes and strategies (prefer finished, but include all)
        all_sizes_seen = {}  # size -> best_state (finished > running > crashed > failed)
        all_strategies_seen = {}  # strategy -> best_state

        state_priority = {"finished": 0, "running": 1, "crashed": 2, "failed": 3}

        for handle, label_key in zip(legend_handles, legend_labels):
            parts = label_key.split("_strategy_")
            if len(parts) != 2:
                continue
            size_part = parts[0].replace("size_", "")
            strategy_state_part = parts[1]
            strategy_state_parts = strategy_state_part.rsplit("_", 1)
            if len(strategy_state_parts) != 2:
                continue
            strategy_part, state_part = strategy_state_parts

            # Track best state for each size
            if size_part not in all_sizes_seen:
                all_sizes_seen[size_part] = state_part
            elif state_priority.get(state_part, 99) < state_priority.get(
                all_sizes_seen[size_part], 99
            ):
                all_sizes_seen[size_part] = state_part

            # Track best state for each strategy
            if strategy_part not in all_strategies_seen:
                all_strategies_seen[strategy_part] = state_part
            elif state_priority.get(state_part, 99) < state_priority.get(
                all_strategies_seen[strategy_part], 99
            ):
                all_strategies_seen[strategy_part] = state_part

        # Build size legend entries for ALL sizes
        for size_part in all_sizes_seen.keys():
            if size_part not in size_colors_used:
                size_key = f"Size {size_part}"
                size_handles.append(
                    Line2D(
                        [0],
                        [0],
                        color=size_color_map[int(size_part)],
                        linewidth=3,
                        label=size_key,
                    )
                )
                size_labels.append(size_key)
                size_colors_used.add(size_part)

        # Build strategy legend entries for ALL strategies (with linestyle AND marker)
        # Use longer line segment [0, 0.5, 1] so linestyle pattern is visible
        for strategy_part in all_strategies_seen.keys():
            if (
                strategy_part not in strategy_linestyles_used
                and strategy_part in strategy_linestyle_map
            ):
                strategy_key = format_strategy_for_legend(strategy_part)
                strategy_handles.append(
                    Line2D(
                        [0, 0.5, 1],
                        [0, 0, 0],
                        color="black",
                        linestyle=strategy_linestyle_map[strategy_part],
                        marker=strategy_marker_map.get(strategy_part, "o"),
                        markersize=10,
                        linewidth=1.5,
                        label=strategy_key,
                    )
                )
                strategy_labels.append(strategy_key)
                strategy_linestyles_used.add(strategy_part)

        # Sort size labels in ascending order
        size_order = sorted(size_labels, key=lambda x: int(x.split()[1]))
        size_handles_sorted = []
        size_labels_sorted = []
        for label in size_order:
            idx = size_labels.index(label)
            size_handles_sorted.append(size_handles[idx])
            size_labels_sorted.append(size_labels[idx])
        size_handles = size_handles_sorted
        size_labels = size_labels_sorted

        # Create legend axes in the bottom gridspec slot
        legend_ax = fig.add_subplot(gs[1])
        legend_ax.axis("off")

        # Create THREE separate legends stacked vertically:
        # 1. Sizes (horizontal, compact) at top
        # 2. Strategies (one per line, full width) in middle
        # 3. Run states (horizontal) at bottom

        # Legend 1: Community Sizes (horizontal row)
        if size_handles:
            size_legend = legend_ax.legend(
                size_handles,
                [f"📏 {label}" for label in size_labels],
                loc="upper center",
                bbox_to_anchor=(0.5, 1.0),
                fontsize=10,
                title="Community Sizes",
                title_fontsize=11,
                ncol=len(size_handles),
                frameon=True,
                fancybox=True,
                columnspacing=2.0,
                handletextpad=0.5,
            )
            legend_ax.add_artist(size_legend)

        # Legend 2: Strategies (single column, full width for long labels)
        # Use handlelength=4 to show linestyle pattern clearly alongside markers
        if strategy_handles:
            strategy_legend = legend_ax.legend(
                strategy_handles,
                [f"📊 {label}" for label in strategy_labels],
                loc="upper center",
                bbox_to_anchor=(0.5, 0.65),
                fontsize=10,
                title="Strategies",
                title_fontsize=11,
                ncol=1,
                frameon=True,
                fancybox=True,
                handlelength=4,
                handletextpad=0.8,
            )
            legend_ax.add_artist(strategy_legend)

        # Legend 3: Run states (horizontal row at bottom)
        state_handles = [
            Line2D([0], [0], color="gray", linestyle="-", alpha=1.0, linewidth=3),
            Line2D([0], [0], color="gray", linestyle="-", alpha=0.5, linewidth=3),
        ]
        state_labels = ["Solid opacity: finished", "Faded opacity: crashed/failed"]
        legend_ax.legend(
            state_handles,
            state_labels,
            loc="upper center",
            bbox_to_anchor=(0.5, 0.15),
            fontsize=9,
            title="Run States",
            title_fontsize=10,
            ncol=2,
            frameon=True,
            fancybox=True,
            columnspacing=3.0,
            handletextpad=0.5,
        )

    # Set y-axis limit with some padding
    if max_y > 0:
        ax.set_ylim(0, max_y * 1.1)

    plt.tight_layout()
    plt.show()

    # Print community comparison statistics with size/strategy breakdown
    print(f"\nCommunity Performance Summary for {env_config_id}:")

    # Group by size and strategy
    size_strategy_stats = {}
    for community_id in sorted(communities_with_data.keys()):
        runs_data = communities_with_data[community_id]
        metadata = community_metadata[community_id]

        size = metadata["size"]
        strategy = metadata["strategy"] or "unknown"

        key = f"Size {size}, Strategy: {format_strategy_for_legend(strategy)}"

        if key not in size_strategy_stats:
            size_strategy_stats[key] = {"finished": [], "total": 0}

        size_strategy_stats[key]["total"] += len(runs_data)
        finished_runs = [r for r in runs_data if r["run_state"] == "finished"]
        size_strategy_stats[key]["finished"].extend(
            [r["max_modes"] for r in finished_runs]
        )

    for group_key, stats in sorted(size_strategy_stats.items()):
        finished_count = len(stats["finished"])
        total_count = stats["total"]

        if finished_count > 0:
            avg_max_modes = sum(stats["finished"]) / finished_count
            min_max = min(stats["finished"])
            max_max = max(stats["finished"])
            print(
                f" {group_key}: {finished_count}/{total_count} finished, "
                f"avg max modes = {avg_max_modes:.1f} (range: {min_max:.0f}-{max_max:.0f})"
            )
        else:
            print(f" {group_key}: {total_count} runs, 0 finished")
```

Copilot AI Jan 14, 2026

The function plot_communities_in_environment has excessive complexity (roughly 300 lines in the quoted range). Consider breaking it down into smaller, more manageable functions for better maintainability. For example, separate the legend creation logic, plot styling setup, and data processing into distinct helper functions.
Comment on lines +101 to +198

```python
    if len(run_ids) <= 3:
        print(f" Run IDs: {run_ids}")
    else:
        print(f" Run IDs: {run_ids[:3]}... ({len(run_ids)} total)")

    # Check environment configuration status distribution
    print("\nEnvironment configuration status analysis:")
    for env_config_id, run_ids in sorted(env_config_runs.items()):
        env_config_run_objects = [run for run in runs_list if run.id in run_ids]
        states = [run.state for run in env_config_run_objects]
        state_counts = pd.Series(states).value_counts()
        print(f"- {env_config_id}: {dict(state_counts)}")

    return run_to_env_config, env_config_runs, env_config_details


def create_hierarchical_structure(runs_list, run_to_env_config, run_to_community):
    """
    Create hierarchical structure: Environment Groups → Community Groups → Runs

    Returns:
    - env_to_communities: dict mapping environment_config_id to list of community_ids
    - community_to_runs: dict mapping community_id to list of run_ids
    - env_community_runs: nested dict[env_config_id][community_id] = list of runs
    """
    env_to_communities = {}
    community_to_runs = {}
    env_community_runs = {}

    for run in runs_list:
        env_config_id = run_to_env_config.get(run.id)
        community_id = run_to_community.get(run.id)

        if env_config_id and community_id:
            # Build environment → communities mapping
            if env_config_id not in env_to_communities:
                env_to_communities[env_config_id] = set()
            env_to_communities[env_config_id].add(community_id)

            # Build community → runs mapping
            if community_id not in community_to_runs:
                community_to_runs[community_id] = []
            community_to_runs[community_id].append(run.id)

            # Build nested env → community → runs mapping
            if env_config_id not in env_community_runs:
                env_community_runs[env_config_id] = {}
            if community_id not in env_community_runs[env_config_id]:
                env_community_runs[env_config_id][community_id] = []
            env_community_runs[env_config_id][community_id].append(run)

    # Convert sets to sorted lists for consistency
    for env_id in env_to_communities:
        env_to_communities[env_id] = sorted(env_to_communities[env_id])

    print("\n=== HIERARCHICAL STRUCTURE ANALYSIS ===")
    print(f"Environments: {len(env_to_communities)}")
    print(f"Total Communities: {len(community_to_runs)}")

    for env_id, communities in sorted(env_to_communities.items()):
        print(f"\nEnvironment {env_id}:")
        print(f" Communities: {len(communities)}")
        for community_id in communities:
            runs_in_community = len(community_to_runs[community_id])
            finished_runs = sum(
                1
                for run in env_community_runs[env_id][community_id]
                if run.state == "finished"
            )
            print(
                f" {community_id}: {runs_in_community} runs ({finished_runs} finished)"
            )

    return env_to_communities, community_to_runs, env_community_runs


def analyze_groups(runs_list):
    """Analyze group structure from run data."""
    print("\n=== WANDB GROUP ANALYSIS ===")
    run_groups = {}
    group_runs = {}

    for run in runs_list:
        group_id = getattr(run, "group", None)
        if group_id:
            run_groups[run.id] = group_id
            if group_id not in group_runs:
                group_runs[group_id] = []
            group_runs[group_id].append(run.id)

    print(f"Found {len(group_runs)} unique wandb groups:")
    for group_id, run_ids in sorted(group_runs.items()):
        print(f"- Group {group_id}: {len(run_ids)} runs")
        # Show first few run IDs for this group
        if len(run_ids) <= 3:
            print(f" Run IDs: {run_ids}")
        else:
            print(f" Run IDs: {run_ids[:3]}... ({len(run_ids)} total)")
```

Copilot AI Jan 14, 2026

The pattern of checking run IDs and displaying them appears multiple times (lines 101-104, 196-198). Consider extracting this into a helper function like format_run_ids_display(run_ids: list, max_display: int = 3) -> str to reduce code duplication.
Comment on lines +20 to +21

```python
warnings.filterwarnings("ignore")
```

Copilot AI Jan 14, 2026

Using a blanket warnings.filterwarnings("ignore") suppresses all warnings globally, which can hide important issues. Consider being more specific about which warnings to ignore (e.g., using category or module filters), or remove this if the warnings are not problematic.

Suggested change: delete the `warnings.filterwarnings("ignore")` line.
Collaborator

@younik left a comment

Non-blocking as it is a script, but I share some of Copilot's comments; I resolved the irrelevant ones.
