For our notes - here's a summary of the work done / next steps:

# TorchGFN Multinode Scaling Analysis - Project State
## Overview
A comprehensive analysis framework for evaluating multinode scaling experiments in TorchGFN using Weights & Biases data. The tool visualizes how different community sizes and strategies affect mode discovery performance across hypergrid environments.
## Current Capabilities
### 📊 Data Pipeline
- **Wandb Integration**: Fetches runs from `torchgfn/torchgfn` project with robust timeout handling
- **Hierarchical Organization**: Environment → Community → Runs structure
- **Strategy Extraction**: Automatically identifies unique strategy configurations across all runs
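A minimal sketch of the Environment → Community → Runs grouping, assuming each run has already been reduced to an `(env_id, community_id, run_id)` tuple. The function name `group_runs` and the identifiers in `records` are made up for illustration; they are not taken from the script or from real wandb data.

```python
from collections import defaultdict


def group_runs(records):
    """Group (env_id, community_id, run_id) tuples into a nested dict."""
    hierarchy = defaultdict(lambda: defaultdict(list))
    for env_id, community_id, run_id in records:
        hierarchy[env_id][community_id].append(run_id)
    # Convert to plain dicts so that a missing key raises KeyError
    # instead of silently creating an empty entry.
    return {env: dict(communities) for env, communities in hierarchy.items()}


# Hypothetical records, for illustration only.
records = [
    ("env_a", "community_1", "run_1"),
    ("env_a", "community_1", "run_2"),
    ("env_a", "community_2", "run_3"),
    ("env_b", "community_3", "run_4"),
]
hierarchy = group_runs(records)
```

Keeping the hierarchy as nested dicts makes it cheap to iterate one environment at a time when plotting.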
### 🎨 Visualization System
#### Multi-Dimensional Encoding
| Dimension | Visual Encoding | Purpose |
|-----------|----------------|---------|
| Community Size | **Color** (tab20 colormap) | Compare scaling (Size 1, 2, 4, 8, 16, 32...) |
| Strategy | **Linestyle + Marker** | Distinguish experimental conditions |
| Run State | **Opacity** | Solid=finished, Faded=crashed/failed |
#### Plot Layout
- **70/30 split**: Main plot takes 70% of vertical space, legend takes 30%
- **Three-section legend**:
  1. Community Sizes (horizontal row, color-coded)
  2. Strategies (single column, linestyle + marker shown)
  3. Run States (opacity explanation)
#### Visual Tuning
- Thin lines (`linewidth=1.2`) to reduce overlap
- Large markers (`markersize=9`) for shape distinction
- ~12 markers per line for visual clarity
- Legend shows full linestyle pattern (`handlelength=4`)
### 🏷️ Strategy Mapping System
```python
STRATEGY_MAPPING = {
    "average_every=100_...": "Baseline",
    "average_every=100_..._use_selective_averaging=True": "Selective Averaging",
    "average_every=100_..._use_random_strategies=True": "Random Strategies",
    # Add more mappings as needed
}
```
### 🖥️ Command Line Interface
```bash
# Full analysis with plots
python multinode_scaling_analysis.py
# Quick mode: just print strategies for mapping
python multinode_scaling_analysis.py --print-strategies-only
```
## Current Strategy Mappings
## File Structure
## Possible Next Steps for Analysis
### 📈 Quantitative Analysis
### 🔬 Deeper Investigations
### 📊 Visualization Enhancements
### 🔧 Tool Improvements
### 📝 Documentation & Reporting
## Technical Stack
## Quick Reference
```python
# Key constants at top of script
STRATEGY_PARAMS = ['average_every', 'replacement_ratio', ...]
STRATEGY_MAPPING = {...}  # Long-form → shorthand
# Key functions
fetch_wandb_runs()  # Get data from wandb
analyze_communities_within_environments()  # Main analysis loop
plot_communities_in_environment()  # Generate plots
format_strategy_for_legend()  # Uses STRATEGY_MAPPING
```
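How a long-form strategy string could be turned into a legend label is sketched below. This is not the script's `format_strategy_for_legend()`: the helper name `shorten_strategy`, the `EXAMPLE_MAPPING` keys, and the fallback behavior are assumptions for illustration, since the real long-form keys are elided in the notes.

```python
# Hypothetical keys; the real long-form strings are elided in the notes.
EXAMPLE_MAPPING = {
    "average_every=100": "Baseline",
    "average_every=100_use_selective_averaging=True": "Selective Averaging",
}


def shorten_strategy(strategy, mapping=EXAMPLE_MAPPING):
    """Return the shorthand for a strategy, or a readable fallback if unmapped."""
    if not strategy:
        return "unknown"
    if strategy in mapping:
        return mapping[strategy]
    # Fallback (an assumption): keep the long form, splitting on underscores.
    return strategy.replace("_", ", ")
```

With this shape, an unmapped strategy still gets a label, so the legend stays complete while the mapping is being filled in.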
Codecov Report
✅ All modified and coverable lines are covered by tests.
@@ Coverage Diff @@
## master #463 +/- ##
===========================================
+ Coverage 0.55% 74.38% +73.83%
===========================================
Files 48 47 -1
Lines 6845 6891 +46
Branches 802 825 +23
===========================================
+ Hits 38 5126 +5088
+ Misses 6806 1454 -5352
- Partials 1 311 +310

    # Show some examples
    print("\nExamples of community vs environment groupings:")
    example_runs = list(run_to_community.keys())[:8]  # Show first 8

The magic number 8 is used to limit example runs displayed. Consider defining this as a named constant for better maintainability.

    fig = plt.figure(figsize=(16, 14))
    gs = fig.add_gridspec(2, 1, height_ratios=[70, 30], hspace=0.08)

The magic numbers for figure sizing (16, 14), height ratios (70, 30), and spacing (0.08) should be defined as named constants for better maintainability and to make it easier to adjust the layout consistently.
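One way this suggestion might look, as a sketch only: the constant names and the `gridspec_kwargs` helper are hypothetical, while the values are the ones in the snippet under review.

```python
# Hypothetical constant names; the values come from the reviewed snippet.
FIGURE_SIZE = (16, 14)      # width, height in inches
PLOT_HEIGHT_RATIO = 70      # share of vertical space for the main plot
LEGEND_HEIGHT_RATIO = 30    # share of vertical space for the legends
PLOT_LEGEND_HSPACE = 0.08   # vertical gap between plot and legend area


def gridspec_kwargs():
    """Keyword arguments intended for fig.add_gridspec(2, 1, **gridspec_kwargs())."""
    return {
        "height_ratios": [PLOT_HEIGHT_RATIO, LEGEND_HEIGHT_RATIO],
        "hspace": PLOT_LEGEND_HSPACE,
    }
```

The plotting code would then read `fig = plt.figure(figsize=FIGURE_SIZE)` followed by `gs = fig.add_gridspec(2, 1, **gridspec_kwargs())`, so the layout can be adjusted in one place.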

    label = f"Size {community_size}, Strategy: {strategy_short} ({run_state})"

    # Determine marker frequency based on data length (show ~10-15 markers per line)
    markevery = max(1, len(steps) // 12)

The magic number 12 is used to calculate marker frequency. Consider defining this as a named constant (e.g., TARGET_MARKERS_PER_LINE = 12) for better maintainability.
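A small sketch of the named-constant version. `TARGET_MARKERS_PER_LINE` is the name proposed in the comment; the helper `marker_interval` is hypothetical.

```python
TARGET_MARKERS_PER_LINE = 12


def marker_interval(n_points, target=TARGET_MARKERS_PER_LINE):
    """Return the `markevery` step so a line shows roughly `target` markers."""
    # max(1, ...) keeps the step valid for short (or empty) series.
    return max(1, n_points // target)
```

For a series of 120 steps this gives a marker every 10 points; any series shorter than 12 points gets a marker on every point.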

    except Exception as e:
        print(f" Error checking history: {e}")

The broad `except Exception as e` on line 227 catches every exception and only prints it. Consider catching specific exceptions (e.g., wandb.errors.CommError, requests.exceptions.RequestException) or re-raising after logging to ensure critical errors are not hidden.

    def plot_communities_in_environment(  # noqa: C901
        env_config_id,
        community_data,
        community_metadata,
        strategy_linestyle_map,
        strategy_marker_map,
    ):
        """Plot n_modes_found progression for communities within a single environment.

        Uses color for community size, linestyle and markers for strategy.
        Linestyle and marker mappings are passed in to ensure consistency across plots.
        """
        if not community_data:
            return

        communities_with_data = {cid: data for cid, data in community_data.items() if data}

        if len(communities_with_data) <= 1:
            print(f"Only {len(communities_with_data)} communities with data - skipping plot")
            return

        # Create figure with explicit layout: 70% plot, 30% legend space
        fig = plt.figure(figsize=(16, 14))
        gs = fig.add_gridspec(2, 1, height_ratios=[70, 30], hspace=0.08)
        ax = fig.add_subplot(gs[0])

        # Extract unique community sizes
        community_sizes = sorted(
            set(community_metadata[cid]["size"] for cid in communities_with_data.keys())
        )

        # Color mapping for community sizes
        size_colors = cm.tab20(np.linspace(0, 1, len(community_sizes)))
        size_color_map = dict(zip(community_sizes, size_colors))

        # Style mapping for run states (secondary modifier)
        state_alphas = {"finished": 1.0, "crashed": 0.7, "failed": 0.5, "running": 0.8}

        max_y = 0
        legend_handles = []
        legend_labels = []

        for community_id in sorted(communities_with_data.keys()):
            runs_data = communities_with_data[community_id]
            if not runs_data:
                continue

            community_size = community_metadata[community_id]["size"]
            community_strategy = community_metadata[community_id]["strategy"]

            # Get color for community size
            base_color = size_color_map[community_size]

            # Get linestyle and marker for strategy (from global mappings)
            base_linestyle = strategy_linestyle_map.get(community_strategy, "-")
            base_marker = strategy_marker_map.get(community_strategy, "o")

            for run_data in runs_data:
                steps = run_data["steps"]
                n_modes = run_data["n_modes_found"]
                run_state = run_data["run_state"]

                # Apply state modifier (alpha)
                alpha = state_alphas.get(run_state, 0.6)

                # Create label for legend
                strategy_short = (
                    community_strategy.replace("_", ", ")
                    if community_strategy
                    else "unknown"
                )
                label = f"Size {community_size}, Strategy: {strategy_short} ({run_state})"

                # Determine marker frequency based on data length (show ~10-15 markers per line)
                markevery = max(1, len(steps) // 12)

                (line,) = ax.plot(
                    steps,
                    n_modes,
                    linestyle=base_linestyle,
                    color=base_color,
                    alpha=alpha,
                    linewidth=1.2,
                    label=label,
                    marker=base_marker,
                    markersize=9,
                    markevery=markevery,
                )

                # Only add to legend if we haven't seen this combination before
                legend_key = (
                    f"size_{community_size}_strategy_{community_strategy}_{run_state}"
                )
                if legend_key not in legend_labels:
                    legend_handles.append(line)
                    legend_labels.append(legend_key)

                max_y = max(max_y, max(n_modes) if n_modes else 0)

        # Clean up the environment config ID for title
        env_title = env_config_id.replace("_", ", ")
        ax.set_title(
            f"Mode Discovery: {env_title}\nCommunity Size (📏) × Strategy (📊) Analysis",
            fontsize=14,
            fontweight="bold",
        )
        ax.set_xlabel("Iteration/Step", fontsize=12)
        ax.set_ylabel("Number of Modes Found", fontsize=12)
        ax.grid(True, alpha=0.3)

        # Create separate legends for colors (sizes) and linestyles (strategies)
        if legend_handles:
            # Create color legend for community sizes
            size_handles = []
            size_labels = []
            size_colors_used = set()

            # Create linestyle legend for strategies
            strategy_handles = []
            strategy_labels = []
            strategy_linestyles_used = set()

            # First pass: collect all unique sizes and strategies (prefer finished, but include all)
            all_sizes_seen = {}  # size -> best_state (finished > running > crashed > failed)
            all_strategies_seen = {}  # strategy -> best_state

            state_priority = {"finished": 0, "running": 1, "crashed": 2, "failed": 3}

            for handle, label_key in zip(legend_handles, legend_labels):
                parts = label_key.split("_strategy_")
                if len(parts) != 2:
                    continue
                size_part = parts[0].replace("size_", "")
                strategy_state_part = parts[1]
                strategy_state_parts = strategy_state_part.rsplit("_", 1)
                if len(strategy_state_parts) != 2:
                    continue
                strategy_part, state_part = strategy_state_parts

                # Track best state for each size
                if size_part not in all_sizes_seen:
                    all_sizes_seen[size_part] = state_part
                elif state_priority.get(state_part, 99) < state_priority.get(
                    all_sizes_seen[size_part], 99
                ):
                    all_sizes_seen[size_part] = state_part

                # Track best state for each strategy
                if strategy_part not in all_strategies_seen:
                    all_strategies_seen[strategy_part] = state_part
                elif state_priority.get(state_part, 99) < state_priority.get(
                    all_strategies_seen[strategy_part], 99
                ):
                    all_strategies_seen[strategy_part] = state_part

            # Build size legend entries for ALL sizes
            for size_part in all_sizes_seen.keys():
                if size_part not in size_colors_used:
                    size_key = f"Size {size_part}"
                    size_handles.append(
                        Line2D(
                            [0],
                            [0],
                            color=size_color_map[int(size_part)],
                            linewidth=3,
                            label=size_key,
                        )
                    )
                    size_labels.append(size_key)
                    size_colors_used.add(size_part)

            # Build strategy legend entries for ALL strategies (with linestyle AND marker)
            # Use longer line segment [0, 0.5, 1] so linestyle pattern is visible
            for strategy_part in all_strategies_seen.keys():
                if (
                    strategy_part not in strategy_linestyles_used
                    and strategy_part in strategy_linestyle_map
                ):
                    strategy_key = format_strategy_for_legend(strategy_part)
                    strategy_handles.append(
                        Line2D(
                            [0, 0.5, 1],
                            [0, 0, 0],
                            color="black",
                            linestyle=strategy_linestyle_map[strategy_part],
                            marker=strategy_marker_map.get(strategy_part, "o"),
                            markersize=10,
                            linewidth=1.5,
                            label=strategy_key,
                        )
                    )
                    strategy_labels.append(strategy_key)
                    strategy_linestyles_used.add(strategy_part)

            # Sort size labels in ascending order
            size_order = sorted(size_labels, key=lambda x: int(x.split()[1]))
            size_handles_sorted = []
            size_labels_sorted = []
            for label in size_order:
                idx = size_labels.index(label)
                size_handles_sorted.append(size_handles[idx])
                size_labels_sorted.append(size_labels[idx])
            size_handles = size_handles_sorted
            size_labels = size_labels_sorted

            # Create legend axes in the bottom gridspec slot
            legend_ax = fig.add_subplot(gs[1])
            legend_ax.axis("off")

            # Create THREE separate legends stacked vertically:
            # 1. Sizes (horizontal, compact) at top
            # 2. Strategies (one per line, full width) in middle
            # 3. Run states (horizontal) at bottom

            # Legend 1: Community Sizes (horizontal row)
            if size_handles:
                size_legend = legend_ax.legend(
                    size_handles,
                    [f"📏 {label}" for label in size_labels],
                    loc="upper center",
                    bbox_to_anchor=(0.5, 1.0),
                    fontsize=10,
                    title="Community Sizes",
                    title_fontsize=11,
                    ncol=len(size_handles),
                    frameon=True,
                    fancybox=True,
                    columnspacing=2.0,
                    handletextpad=0.5,
                )
                legend_ax.add_artist(size_legend)

            # Legend 2: Strategies (single column, full width for long labels)
            # Use handlelength=4 to show linestyle pattern clearly alongside markers
            if strategy_handles:
                strategy_legend = legend_ax.legend(
                    strategy_handles,
                    [f"📊 {label}" for label in strategy_labels],
                    loc="upper center",
                    bbox_to_anchor=(0.5, 0.65),
                    fontsize=10,
                    title="Strategies",
                    title_fontsize=11,
                    ncol=1,
                    frameon=True,
                    fancybox=True,
                    handlelength=4,
                    handletextpad=0.8,
                )
                legend_ax.add_artist(strategy_legend)

            # Legend 3: Run states (horizontal row at bottom)
            state_handles = [
                Line2D([0], [0], color="gray", linestyle="-", alpha=1.0, linewidth=3),
                Line2D([0], [0], color="gray", linestyle="-", alpha=0.5, linewidth=3),
            ]
            state_labels = ["Solid opacity: finished", "Faded opacity: crashed/failed"]
            legend_ax.legend(
                state_handles,
                state_labels,
                loc="upper center",
                bbox_to_anchor=(0.5, 0.15),
                fontsize=9,
                title="Run States",
                title_fontsize=10,
                ncol=2,
                frameon=True,
                fancybox=True,
                columnspacing=3.0,
                handletextpad=0.5,
            )

        # Set y-axis limit with some padding
        if max_y > 0:
            ax.set_ylim(0, max_y * 1.1)

        plt.tight_layout()
        plt.show()

        # Print community comparison statistics with size/strategy breakdown
        print(f"\nCommunity Performance Summary for {env_config_id}:")

        # Group by size and strategy
        size_strategy_stats = {}
        for community_id in sorted(communities_with_data.keys()):
            runs_data = communities_with_data[community_id]
            metadata = community_metadata[community_id]

            size = metadata["size"]
            strategy = metadata["strategy"] or "unknown"

            key = f"Size {size}, Strategy: {format_strategy_for_legend(strategy)}"

            if key not in size_strategy_stats:
                size_strategy_stats[key] = {"finished": [], "total": 0}

            size_strategy_stats[key]["total"] += len(runs_data)
            finished_runs = [r for r in runs_data if r["run_state"] == "finished"]
            size_strategy_stats[key]["finished"].extend(
                [r["max_modes"] for r in finished_runs]
            )

        for group_key, stats in sorted(size_strategy_stats.items()):
            finished_count = len(stats["finished"])
            total_count = stats["total"]

            if finished_count > 0:
                avg_max_modes = sum(stats["finished"]) / finished_count
                min_max = min(stats["finished"])
                max_max = max(stats["finished"])
                print(
                    f" {group_key}: {finished_count}/{total_count} finished, "
                    f"avg max modes = {avg_max_modes:.1f} (range: {min_max:.0f}-{max_max:.0f})"
                )
            else:
                print(f" {group_key}: {total_count} runs, 0 finished")

The function plot_communities_in_environment has excessive complexity: it suppresses the C901 complexity check with `# noqa: C901`. Consider breaking it down into smaller, more manageable functions for better maintainability. For example, separate the legend creation logic, plot styling setup, and data processing into distinct helper functions.
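As one illustration of such a split, the per-group statistics at the end of the function could move into a helper. This is a sketch under assumptions: `summarize_max_modes` is a hypothetical name, and each run is assumed to be a dict with `run_state` and `max_modes` keys, as in the reviewed code.

```python
def summarize_max_modes(runs):
    """Summarize max_modes over the finished runs in a list of run dicts."""
    finished = [r["max_modes"] for r in runs if r["run_state"] == "finished"]
    summary = {"total": len(runs), "finished": len(finished)}
    if finished:
        summary["mean"] = sum(finished) / len(finished)
        summary["min"] = min(finished)
        summary["max"] = max(finished)
    return summary


# Hypothetical runs, for illustration only.
example_runs = [
    {"run_state": "finished", "max_modes": 10},
    {"run_state": "finished", "max_modes": 20},
    {"run_state": "crashed", "max_modes": 3},
]
summary = summarize_max_modes(example_runs)
```

A helper like this has no plotting dependencies, so it can be unit-tested on its own, and the plotting function only has to format the returned dict.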

            if len(run_ids) <= 3:
                print(f" Run IDs: {run_ids}")
            else:
                print(f" Run IDs: {run_ids[:3]}... ({len(run_ids)} total)")

        # Check environment configuration status distribution
        print("\nEnvironment configuration status analysis:")
        for env_config_id, run_ids in sorted(env_config_runs.items()):
            env_config_run_objects = [run for run in runs_list if run.id in run_ids]
            states = [run.state for run in env_config_run_objects]
            state_counts = pd.Series(states).value_counts()
            print(f"- {env_config_id}: {dict(state_counts)}")

        return run_to_env_config, env_config_runs, env_config_details


    def create_hierarchical_structure(runs_list, run_to_env_config, run_to_community):
        """
        Create hierarchical structure: Environment Groups → Community Groups → Runs

        Returns:
        - env_to_communities: dict mapping environment_config_id to list of community_ids
        - community_to_runs: dict mapping community_id to list of run_ids
        - env_community_runs: nested dict[env_config_id][community_id] = list of runs
        """
        env_to_communities = {}
        community_to_runs = {}
        env_community_runs = {}

        for run in runs_list:
            env_config_id = run_to_env_config.get(run.id)
            community_id = run_to_community.get(run.id)

            if env_config_id and community_id:
                # Build environment → communities mapping
                if env_config_id not in env_to_communities:
                    env_to_communities[env_config_id] = set()
                env_to_communities[env_config_id].add(community_id)

                # Build community → runs mapping
                if community_id not in community_to_runs:
                    community_to_runs[community_id] = []
                community_to_runs[community_id].append(run.id)

                # Build nested env → community → runs mapping
                if env_config_id not in env_community_runs:
                    env_community_runs[env_config_id] = {}
                if community_id not in env_community_runs[env_config_id]:
                    env_community_runs[env_config_id][community_id] = []
                env_community_runs[env_config_id][community_id].append(run)

        # Convert sets to sorted lists for consistency
        for env_id in env_to_communities:
            env_to_communities[env_id] = sorted(env_to_communities[env_id])

        print("\n=== HIERARCHICAL STRUCTURE ANALYSIS ===")
        print(f"Environments: {len(env_to_communities)}")
        print(f"Total Communities: {len(community_to_runs)}")

        for env_id, communities in sorted(env_to_communities.items()):
            print(f"\nEnvironment {env_id}:")
            print(f" Communities: {len(communities)}")
            for community_id in communities:
                runs_in_community = len(community_to_runs[community_id])
                finished_runs = sum(
                    1
                    for run in env_community_runs[env_id][community_id]
                    if run.state == "finished"
                )
                print(
                    f" {community_id}: {runs_in_community} runs ({finished_runs} finished)"
                )

        return env_to_communities, community_to_runs, env_community_runs


    def analyze_groups(runs_list):
        """Analyze group structure from run data."""
        print("\n=== WANDB GROUP ANALYSIS ===")
        run_groups = {}
        group_runs = {}

        for run in runs_list:
            group_id = getattr(run, "group", None)
            if group_id:
                run_groups[run.id] = group_id
                if group_id not in group_runs:
                    group_runs[group_id] = []
                group_runs[group_id].append(run.id)

        print(f"Found {len(group_runs)} unique wandb groups:")
        for group_id, run_ids in sorted(group_runs.items()):
            print(f"- Group {group_id}: {len(run_ids)} runs")
            # Show first few run IDs for this group
            if len(run_ids) <= 3:
                print(f" Run IDs: {run_ids}")
            else:
                print(f" Run IDs: {run_ids[:3]}... ({len(run_ids)} total)")

The pattern of checking run IDs and displaying them appears multiple times (lines 101-104, 196-198). Consider extracting this into a helper function like format_run_ids_display(run_ids: list, max_display: int = 3) -> str to reduce code duplication.
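A minimal sketch of the proposed helper, using the signature from the comment. The exact output format is copied from the `print` calls in the reviewed code; returning the string instead of printing it is an assumption.

```python
def format_run_ids_display(run_ids: list, max_display: int = 3) -> str:
    """Format run IDs for printing, truncating long lists."""
    if len(run_ids) <= max_display:
        return f"Run IDs: {run_ids}"
    return f"Run IDs: {run_ids[:max_display]}... ({len(run_ids)} total)"
```

Each call site would then reduce to a single `print(format_run_ids_display(run_ids))`, with any leading indentation added by the caller.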

    warnings.filterwarnings("ignore")

Using a blanket warnings.filterwarnings("ignore") suppresses all warnings globally, which can hide important issues. Consider being more specific about which warnings to ignore (e.g., using category or module filters), or remove this if the warnings are not problematic.
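A sketch of a category-specific filter using only the standard library. Which category the script actually needs to silence is not stated in the thread, so `DeprecationWarning` here is an assumption, and `noisy` is a made-up stand-in for library code that emits warnings.

```python
import warnings


def noisy():
    """Stand-in for library code that emits two kinds of warnings."""
    warnings.warn("old API", DeprecationWarning)
    warnings.warn("check your data", UserWarning)
    return "done"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # start from a known filter state
    # Ignore only one category instead of every warning.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    result = noisy()

# Categories of the warnings that were still reported.
remaining = [w.category.__name__ for w in caught]
```

Only the `UserWarning` is recorded, so unrelated problems stay visible. `filterwarnings` also accepts a `module` regex if the noise comes from a single package.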
younik left a comment:
Non-blocking, as it is a script, but I share some of Copilot's comments; I resolved the irrelevant ones.
Description
Plotting tool for the multinode experiments.
Automatically scrapes wandb for results and plots them. To be expanded.