diff --git a/README.md b/README.md index 325c45f..fc4b6e3 100644 --- a/README.md +++ b/README.md @@ -1,147 +1,151 @@ # OpenEuroLLM CLI (oellm) -A package for running OELLM CLI workflows across multiple HPC clusters using SLURM job arrays and Singularity containers. +A lightweight CLI for scheduling LLM evaluations across multiple HPC clusters using SLURM job arrays and Singularity containers. -## Currently supported workflows -- Schedule evaluations on multiple models and tasks on all clusters ✅ `oellm schedule-eval ...` -- Restart failed evaluations (e.g., due to node failures) ✅ `oellm collect-results ... --reschedule true` +## Features -## Planned workflows -- Sync and download evaluation results from all clusters via a shared data layer -- Schedule training jobs on all clusters -- Schedule conversions from MegatronLM to HuggingFace +- **Schedule evaluations** on multiple models and tasks: `oellm schedule-eval` +- **Collect results** and check for missing evaluations: `oellm collect-results` +- **Task groups** for pre-defined evaluation suites with automatic dataset pre-downloading +- **Multi-cluster support** with auto-detection (Leonardo, LUMI, JURECA) -## Quick Example +## Quick Start -**Prerequisites:** -- install [uv](https://docs.astral.sh/uv/#installation) +**Prerequisites:** Install [uv](https://docs.astral.sh/uv/#installation) ```bash # Install the package uv tool install -p 3.12 git+https://github.com/OpenEuroLLM/oellm-cli.git -# Run evaluations on multiple models and tasks +# Run evaluations using a task group (recommended) oellm schedule-eval \ --models "microsoft/DialoGPT-medium,EleutherAI/pythia-160m" \ + --task_groups "open-sci-0.01" + +# Or specify individual tasks +oellm schedule-eval \ + --models "EleutherAI/pythia-160m" \ --tasks "hellaswag,mmlu" \ --n_shot 5 ``` This will automatically: - Detect your current HPC cluster (Leonardo, LUMI, or JURECA) -- Download and cache the specified models and datasets -- Generate a SLURM job array to evaluate all model-task combinations -- Submit the jobs with appropriate cluster-specific resource allocations +- Download and cache the specified models +- Pre-download datasets for known tasks (see warning below) +- Generate and submit a SLURM job array with appropriate cluster-specific resources + +## Task Groups + +Task groups are pre-defined evaluation suites in [`task-groups.yaml`](oellm/resources/task-groups.yaml). Each group specifies tasks, their n-shot settings, and HuggingFace dataset mappings. + +Available task groups: +- `open-sci-0.01` - Standard benchmarks (COPA, MMLU, HellaSwag, ARC, etc.) +- `belebele-eu-5-shot` - Belebele European language tasks +- `flores-200-eu-to-eng` / `flores-200-eng-to-eu` - Translation tasks +- `global-mmlu-eu` - Global MMLU in EU languages +- `mgsm-eu` - Multilingual GSM benchmarks +- `generic-multilingual` - XWinograd, XCOPA, XStoryCloze +- `include` - INCLUDE benchmarks -In case you meet HuggingFace quotas issues, make sure you are logged in by setting your `HF_TOKEN` and that you are part of [OpenEuroLLM](https://huggingface.co/OpenEuroLLM) organization. 
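+
+Each task entry in the YAML pairs a task name with its n-shot settings and HuggingFace dataset mapping. As an illustration, an abridged `open-sci-0.01` entry looks roughly like this:
+
+```yaml
+task_groups:
+  open-sci-0.01:
+    tasks:
+      - task: mmlu
+        n_shots: [5]
+        dataset: cais/mmlu
+        subset: all
+      - task: hellaswag
+        n_shots: [10]
+        dataset: Rowan/hellaswag
+```
+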
+Super groups combine multiple task groups: +- `oellm-multilingual` - All multilingual benchmarks combined -You can also directly schedule using a CSV file: ```bash -oellm schedule-eval --eval_csv_path custom_evals.csv +# Use a task group +oellm schedule-eval --models "model-name" --task_groups "open-sci-0.01" + +# Use multiple task groups +oellm schedule-eval --models "model-name" --task_groups "belebele-eu-5-shot,global-mmlu-eu" + +# Use a super group +oellm schedule-eval --models "model-name" --task_groups "oellm-multilingual" ``` -## Installation +## ⚠️ Dataset Pre-Download Warning -### JURECA/JSC Specifics +**Datasets are only automatically pre-downloaded for tasks defined in [`task-groups.yaml`](oellm/resources/task-groups.yaml).** + +If you use custom tasks via `--tasks` that are not in the task groups registry, the CLI will attempt to look them up but **cannot guarantee the datasets will be cached**. This may cause failures on compute nodes that don't have network access. + +**Recommendation:** Use `--task_groups` when possible, or ensure your custom task datasets are already cached in `$HF_HOME` before scheduling. -Due to the limit space in `$HOME` on JSC clusters, you must set these `uv` specific environment variables to avoid running out of space: +## Collecting Results + +After evaluations complete, collect results into a CSV: ```bash -export UV_CACHE_DIR="/.cache/uv-cache" -export UV_INSTALL_DIR="/.local" -export UV_PYTHON_INSTALL_DIR="/.local/share/uv/python" -export UV_TOOL_DIR="/.cache/uv-tool-cache" +# Basic collection +oellm collect-results /path/to/eval-output-dir + +# Check for missing evaluations and create a CSV for re-running them +oellm collect-results /path/to/eval-output-dir --check --output_csv results.csv ``` -You can set these variables in your `.bashrc` or `.zshrc` file, depending on your shell of preference. +The `--check` flag compares completed results against `jobs.csv` and outputs a `results_missing.csv` that can be used to re-schedule failed jobs: -E.g., I have a user-folder in the `synthlaion` project, so I set the following variables: ```bash -export UV_CACHE_DIR="/p/project1/synthlaion/$USER/.cache/uv-cache" -export UV_INSTALL_DIR="/p/project1/synthlaion/$USER/.local" -export UV_PYTHON_INSTALL_DIR="/p/project1/synthlaion/$USER/.local/share/uv/python" -export UV_TOOL_DIR="/p/project1/synthlaion/$USER/.cache/uv-tool-cache" +oellm schedule-eval --eval_csv_path results_missing.csv ``` -### General Installation +## CSV-Based Scheduling -Install directly from the git repository using uv: +For full control, provide a CSV file with columns: `model_path`, `task_path`, `n_shot`, and optionally `eval_suite`: ```bash -uv tool install -p 3.12 git+https://github.com/OpenEuroLLM/oellm-cli.git +oellm schedule-eval --eval_csv_path custom_evals.csv ``` -This makes the `oellm` command available globally in your shell. +## Installation + +### General Installation -If you've already installed the package, you can run the following command to update it: ```bash -uv tool upgrade oellm +uv tool install -p 3.12 git+https://github.com/OpenEuroLLM/oellm-cli.git ``` -If you had previously installed the package from a different source and would like to overwrite it, you can run the following command: +Update to latest: ```bash -uv tool install -p 3.12 git+https://github.com/OpenEuroLLM/oellm-cli.git --force +uv tool upgrade oellm ``` -## High-Level Evaluation Workflow - -The `oellm` package orchestrates distributed LLM evaluations through the following workflow: - -### 1. 
**Cluster Auto-Detection**
-- Automatically detects the current HPC cluster based on hostname patterns
-- Loads cluster-specific configurations from [`clusters.yaml`](oellm/resources/clusters.yaml) including:
-  - SLURM partition and account settings
-  - Shared storage paths for models, datasets, and results
-  - GPU allocation and queue limits
-  - Singularity container specifications
-
-### 2. **Resource Preparation**
-- **Model Handling**: Processes both local model checkpoints and Hugging Face Hub models
-  - For local paths: Automatically discovers and expands training checkpoint directories
-  - For HF models: Pre-downloads to shared cache (`$HF_HOME`) for offline access on compute nodes
-- **Dataset Caching**: Pre-downloads all evaluation datasets using lm-evaluation-harness TaskManager
-- **Container Management**: Ensures the appropriate Singularity container is available for the target cluster
-
-### 3. **Job Generation & Scheduling**
-- Creates a comprehensive CSV manifest of all model-task-shot combinations
-- Generates a SLURM batch script from a template with cluster-specific parameters
-- Submits a job array where each array task processes a subset of evaluations
-- Respects queue limits and current user load to avoid overwhelming the scheduler
-
-### 4. **Distributed Execution**
-- Each SLURM array job runs in a Singularity container with:
-  - GPU access (NVIDIA CUDA or AMD ROCm as appropriate)
-  - Mounted shared storage for models, datasets, and output
-  - Offline execution using pre-cached resources
-- Uses `lm-evaluation-harness` for the actual model evaluation
-- Outputs results as JSON files
-
-### 5. **Output Organization**
-Results are organized in timestamped directories under `$EVAL_OUTPUT_DIR/$USER/`:
-```
-2024-01-15-14-30-45/
-├── jobs.csv                    # Complete evaluation manifest
-├── submit_evals.sbatch         # Generated SLURM script
-├── slurm_logs/                 # SLURM output/error logs
-└── results/                    # Evaluation JSON outputs
-```
+### JURECA/JSC Specifics
+
+Due to limited space in `$HOME` on JSC clusters, set these environment variables (e.g., in your `.bashrc`), replacing `<project>` with your project directory:
 
-## Supported Clusters
+```bash
+export UV_CACHE_DIR="/p/project1/<project>/$USER/.cache/uv-cache"
+export UV_INSTALL_DIR="/p/project1/<project>/$USER/.local"
+export UV_PYTHON_INSTALL_DIR="/p/project1/<project>/$USER/.local/share/uv/python"
+export UV_TOOL_DIR="/p/project1/<project>/$USER/.cache/uv-tool-cache"
+```
 
-Currently supports three HPC clusters:
+## Supported Clusters
+
+- **Leonardo** - NVIDIA A100 GPUs (CUDA)
+- **LUMI** - AMD MI250X GPUs (ROCm)
+- **JURECA** - NVIDIA A100 GPUs (CUDA)
 
-- **LEONARDO** - NVIDIA A100 GPUs (CUDA)
-- **LUMI** - AMD MI250X GPUs (ROCm)
-- **JURECA** - NVIDIA A100 GPUs (CUDA)
+## CLI Options
 
-Each cluster has pre-configured:
-- Shared evaluation directories with appropriate quotas
-- Optimized Singularity containers with evaluation dependencies
-- Account and partition settings for the OpenEuroLLM project
+See all available flags for a subcommand with:
 
-## Development and Testing
-Run in download-only mode to prepare resources without submitting jobs:
+```bash
+oellm schedule-eval --help
+```
+
+## Development
 
 ```bash
-oellm schedule-eval --models "EleutherAI/pythia-160m" --tasks "hellaswag" --n_shot 0 --download_only True
+# Clone and install in dev mode
+git clone https://github.com/OpenEuroLLM/oellm-cli.git
+cd oellm-cli
+uv sync --extra dev
+
+# Run dataset validation tests
+uv run pytest tests/test_datasets.py -v
+
+# Download-only mode for testing
+uv run oellm schedule-eval --models "EleutherAI/pythia-160m" --task_groups "open-sci-0.01" --download_only
 ```
+
+## Troubleshooting
+
+**HuggingFace quota issues**: Ensure you're logged in with `HF_TOKEN` and are part of the 
[OpenEuroLLM](https://huggingface.co/OpenEuroLLM) organization. + +**Dataset download failures on compute nodes**: Use `--task_groups` for automatic dataset caching, or pre-download datasets manually before scheduling. diff --git a/oellm/main.py b/oellm/main.py index c1d7f5d..81e6680 100644 --- a/oellm/main.py +++ b/oellm/main.py @@ -1,4 +1,5 @@ import logging +import math import os import re import subprocess @@ -8,20 +9,21 @@ from pathlib import Path from string import Template -import numpy as np import pandas as pd from jsonargparse import auto_cli -from oellm.task_cache import clear_task_cache -from oellm.task_groups import _expand_task_groups +from oellm.task_groups import ( + _collect_dataset_specs, + _expand_task_groups, + _lookup_dataset_specs_for_tasks, +) from oellm.utils import ( _ensure_singularity_image, _expand_local_model_paths, _filter_warnings, _load_cluster_env, _num_jobs_in_queue, - _pre_download_lighteval_datasets, - _pre_download_task_datasets, + _pre_download_datasets_from_specs, _process_model_paths, _setup_logging, capture_third_party_output_from_kwarg, @@ -199,17 +201,24 @@ def schedule_evals( # Ensure that all datasets required by the tasks are cached locally to avoid # network access on compute nodes. if not skip_checks: - lm_eval_tasks = df[df["eval_suite"].isin({"lm-eval-harness"})][ - "task_path" - ].unique() - if len(lm_eval_tasks) > 0: - _pre_download_task_datasets( - lm_eval_tasks, trust_remote_code=trust_remote_code + dataset_specs = [] + if task_groups: + dataset_specs = _collect_dataset_specs( + [g.strip() for g in task_groups.split(",")] + ) + else: + # Look up individual tasks in task groups registry + all_tasks = df["task_path"].unique().tolist() + dataset_specs = _lookup_dataset_specs_for_tasks(all_tasks) + if not dataset_specs: + logging.info( + "No dataset specs found for tasks; skipping dataset pre-download" + ) + + if dataset_specs: + _pre_download_datasets_from_specs( + dataset_specs, trust_remote_code=trust_remote_code ) - # Pre-download LightEval datasets (best-effort, incremental support) - light_eval_tasks = df[df["eval_suite"].isin({"light-eval"})]["task_path"].unique() - if len(light_eval_tasks) > 0: - _pre_download_lighteval_datasets(light_eval_tasks) else: logging.info("Skipping dataset pre-download (--skip-checks enabled)") @@ -254,15 +263,15 @@ def schedule_evals( minutes_per_eval = 10 # Budget 10 minutes per eval total_minutes = total_evals * minutes_per_eval max_minutes_per_job = 18 * 60 # 18 hours - min_array_size_for_time = max(1, int(np.ceil(total_minutes / max_minutes_per_job))) + min_array_size_for_time = max(1, int(math.ceil(total_minutes / max_minutes_per_job))) desired_array_size = min(128, total_evals) if total_evals >= 128 else total_evals if desired_array_size < min_array_size_for_time: desired_array_size = min_array_size_for_time actual_array_size = min(remaining_queue_capacity, desired_array_size, total_evals) - evals_per_job = max(1, int(np.ceil(total_evals / actual_array_size))) + evals_per_job = max(1, int(math.ceil(total_evals / actual_array_size))) minutes_per_job = evals_per_job * minutes_per_eval minutes_with_margin = int(minutes_per_job * 1.2) - hours_with_margin = max(1, int(np.ceil(minutes_with_margin / 60))) + hours_with_margin = max(1, int(math.ceil(minutes_with_margin / 60))) hours_with_margin = max(hours_with_margin, 3) hours_with_margin = min(hours_with_margin, 23) time_limit = f"{hours_with_margin:02d}:59:00" @@ -633,7 +642,6 @@ def main(): { "schedule-eval": schedule_evals, "collect-results": 
collect_results, - "clean-cache": lambda: clear_task_cache(), }, as_positional=False, description="OELLM: Multi-cluster evaluation tool for language models", diff --git a/oellm/resources/task-groups.yaml b/oellm/resources/task-groups.yaml index 69ca6c8..fdf56d9 100644 --- a/oellm/resources/task-groups.yaml +++ b/oellm/resources/task-groups.yaml @@ -5,60 +5,102 @@ task_groups: tasks: - task: copa n_shots: [0] + dataset: super_glue + subset: copa - task: social_iqa n_shots: [0] + dataset: allenai/social_i_qa - task: openbookqa n_shots: [0] + dataset: allenai/openbookqa - task: lambada_openai n_shots: [0] + dataset: EleutherAI/lambada_openai - task: winogrande n_shots: [0] + dataset: allenai/winogrande + subset: winogrande_xl - task: mmlu n_shots: [5] + dataset: cais/mmlu + subset: all - task: hellaswag n_shots: [10] + dataset: Rowan/hellaswag - task: arc_easy n_shots: [10] + dataset: allenai/ai2_arc + subset: ARC-Easy - task: arc_challenge n_shots: [10] + dataset: allenai/ai2_arc + subset: ARC-Challenge - task: commonsense_qa n_shots: [10] + dataset: tau/commonsense_qa - task: piqa n_shots: [10] + dataset: ybisk/piqa - task: boolq n_shots: [10] + dataset: google/boolq belebele-eu-5-shot: description: "Belebele European language tasks" suite: lm-eval-harness n_shots: [5] + dataset: facebook/belebele tasks: - task: belebele_bul_Cyrl + subset: bul_Cyrl - task: belebele_hrv_Latn + subset: hrv_Latn - task: belebele_ces_Latn + subset: ces_Latn - task: belebele_dan_Latn + subset: dan_Latn - task: belebele_nld_Latn + subset: nld_Latn - task: belebele_eng_Latn + subset: eng_Latn - task: belebele_est_Latn + subset: est_Latn - task: belebele_fin_Latn + subset: fin_Latn - task: belebele_fra_Latn + subset: fra_Latn - task: belebele_deu_Latn + subset: deu_Latn - task: belebele_ell_Grek + subset: ell_Grek - task: belebele_hun_Latn + subset: hun_Latn - task: belebele_ita_Latn + subset: ita_Latn - task: belebele_lvs_Latn + subset: lvs_Latn - task: belebele_lit_Latn + subset: lit_Latn - task: belebele_mlt_Latn + subset: mlt_Latn - task: belebele_pol_Latn + subset: pol_Latn - task: belebele_por_Latn + subset: por_Latn - task: belebele_ron_Latn + subset: ron_Latn - task: belebele_slk_Latn + subset: slk_Latn - task: belebele_slv_Latn + subset: slv_Latn - task: belebele_spa_Latn + subset: spa_Latn - task: belebele_swe_Latn + subset: swe_Latn flores-200-eu-to-eng: description: "Flores 200 EU to English translation" suite: lighteval n_shots: [0] + dataset: facebook/flores tasks: - task: flores200:bul_Cyrl-eng_Latn - task: flores200:ces_Latn-eng_Latn @@ -87,6 +129,7 @@ task_groups: description: "Flores 200 English to EU translation" suite: lighteval n_shots: [0] + dataset: facebook/flores tasks: - task: flores200:eng_Latn-bul_Cyrl - task: flores200:eng_Latn-ces_Latn @@ -115,34 +158,58 @@ task_groups: description: "Global MMLU EU tasks" suite: lm-eval-harness n_shots: [5] + dataset: CohereForAI/Global-MMLU tasks: - task: global_mmlu_full_cs + subset: cs - task: global_mmlu_full_de + subset: de - task: global_mmlu_full_el + subset: el - task: global_mmlu_full_en + subset: en - task: global_mmlu_full_es + subset: es - task: global_mmlu_full_fr + subset: fr - task: global_mmlu_full_it + subset: it - task: global_mmlu_full_lt + subset: lt - task: global_mmlu_full_nl + subset: nl - task: global_mmlu_full_pl + subset: pl - task: global_mmlu_full_pt + subset: pt - task: global_mmlu_full_ro + subset: ro - task: global_mmlu_full_ru + subset: ru - task: global_mmlu_full_sr + subset: sr - task: global_mmlu_full_sv + subset: sv 
- task: global_mmlu_full_tr + subset: tr - task: global_mmlu_full_uk + subset: uk - task: global_mmlu_full_he + subset: he mgsm-eu: description: "EU Language GSM benchmarks in Aya Expanse" suite: lm-eval-harness n_shots: [5] + dataset: juletxara/mgsm tasks: - task: mgsm_native_cot_en + subset: en - task: mgsm_native_cot_de + subset: de - task: mgsm_native_cot_es + subset: es - task: mgsm_native_cot_fr + subset: fr generic-multilingual: description: "Generic multilingual benchmarks in Aya Expanse" @@ -150,39 +217,68 @@ task_groups: n_shots: [0] tasks: - task: xwinograd + dataset: Muennighoff/xwinograd - task: xcopa + dataset: cambridgeltl/xcopa - task: xstorycloze + dataset: juletxara/xstory_cloze include: description: "INCLUDE benchmarks in Aya Expanse" suite: lm-eval-harness n_shots: [0] + dataset: CohereForAI/include-base-44 tasks: - task: include_base_44_albanian + subset: Albanian - task: include_base_44_armenian + subset: Armenian - task: include_base_44_azerbaijani + subset: Azerbaijani - task: include_base_44_basque + subset: Basque - task: include_base_44_belarusian + subset: Belarusian - task: include_base_44_bulgarian + subset: Bulgarian - task: include_base_44_croatian + subset: Croatian - task: include_base_44_dutch + subset: Dutch - task: include_base_44_estonian + subset: Estonian - task: include_base_44_finnish + subset: Finnish - task: include_base_44_french + subset: French - task: include_base_44_georgian + subset: Georgian - task: include_base_44_german + subset: German - task: include_base_44_greek + subset: Greek - task: include_base_44_hungarian + subset: Hungarian - task: include_base_44_italian + subset: Italian - task: include_base_44_lithuanian + subset: Lithuanian - task: include_base_44_north macedonian + subset: North Macedonian - task: include_base_44_polish + subset: Polish - task: include_base_44_portuguese + subset: Portuguese - task: include_base_44_russian + subset: Russian - task: include_base_44_serbian + subset: Serbian - task: include_base_44_spanish + subset: Spanish - task: include_base_44_turkish + subset: Turkish - task: include_base_44_ukrainian + subset: Ukrainian super_groups: oellm-multilingual: diff --git a/oellm/task_cache.py b/oellm/task_cache.py deleted file mode 100644 index a320bee..0000000 --- a/oellm/task_cache.py +++ /dev/null @@ -1,330 +0,0 @@ -import json -from contextlib import contextmanager -from contextvars import ContextVar -from datetime import datetime -from pathlib import Path - -TASK_CACHE_TTL_DAYS = 30 - - -_CURRENT_CAPTURE_BUFFER: ContextVar[list[dict] | None] = ContextVar( - "_CURRENT_CAPTURE_BUFFER", default=None -) - - -def get_task_cache_file() -> Path: - return Path(__file__).resolve().parent / "resources" / "task_map_cache.json" - - -def load_task_cache() -> dict: - cache_file = get_task_cache_file() - if cache_file.exists(): - with open(cache_file) as f: - return json.load(f) or {} - return {} - - -def save_task_cache(cache: dict) -> None: - cache_file = get_task_cache_file() - with open(cache_file, "w") as f: - json.dump(cache, f, indent=2, sort_keys=True) - - -def clear_task_cache() -> None: - cache_file = get_task_cache_file() - with open(cache_file, "w") as f: - json.dump({}, f) - - -def task_cache_key(framework: str, task_id: str) -> str: - return f"{framework}::{task_id}" - - -def task_cache_is_fresh(entry: dict, ttl_days: int = TASK_CACHE_TTL_DAYS) -> bool: - ts = float(entry.get("ts", 0)) - age_days = (datetime.now().timestamp() - ts) / 86400.0 - return age_days >= 0 and age_days < float(ttl_days) - - -def 
task_cache_lookup( - framework: str, task_id: str, ttl_days: int = TASK_CACHE_TTL_DAYS -) -> bool: - cache = load_task_cache() - key = task_cache_key(framework, task_id) - entry = cache.get(key) - if not isinstance(entry, dict): - return False - return task_cache_is_fresh(entry, ttl_days) - - -def task_cache_mark_resolved(framework: str, task_id: str) -> None: - cache = load_task_cache() - key = task_cache_key(framework, task_id) - entry = cache.get(key) if isinstance(cache.get(key), dict) else {} - entry["ts"] = datetime.now().timestamp() - cache[key] = entry - save_task_cache(cache) - - -def task_cache_get_payload(framework: str, task_id: str) -> dict | None: - cache = load_task_cache() - key = task_cache_key(framework, task_id) - entry = cache.get(key) - if not isinstance(entry, dict): - return None - payload = entry.get("payload") - return payload if isinstance(payload, dict) else None - - -def task_cache_set_payload(framework: str, task_id: str, payload: dict) -> None: - cache = load_task_cache() - key = task_cache_key(framework, task_id) - entry: dict = cache.get(key) if isinstance(cache.get(key), dict) else {} # type: ignore[assignment] - entry["ts"] = datetime.now().timestamp() - entry["payload"] = payload - cache[key] = entry - save_task_cache(cache) - - -def _canonical_key(call: dict) -> tuple: - t = call.get("type") - if t == "load_dataset": - return ( - t, - call.get("path"), - call.get("name"), - call.get("split"), - call.get("revision"), - ) - if t == "snapshot_download": - return ( - t, - call.get("repo_id"), - call.get("repo_type"), - call.get("revision"), - ) - if t == "hf_hub_download": - return ( - t, - call.get("repo_id"), - call.get("filename"), - call.get("repo_type"), - call.get("revision"), - ) - return (str(t),) - - -def dedupe_calls(calls: list[dict]) -> list[dict]: - if not isinstance(calls, list): - return [] - best: dict[tuple, dict] = {} - for c in calls: - if not isinstance(c, dict): - continue - key = _canonical_key(c) - existing = best.get(key) - if existing is None: - best[key] = c - continue - # Prefer trust_remote_code=True for load_dataset - if c.get("type") == "load_dataset": - if bool(c.get("trust_remote_code")) and not bool( - existing.get("trust_remote_code") - ): - best[key] = c - # Optionally drop snapshot_download if matching load_dataset exists - filtered: list[dict] = [] - load_keys = { - ("load_dataset", k[1], k[2], k[3], k[4]) - for k in best.keys() - if k and k[0] == "load_dataset" - } - for k, v in best.items(): - if k and k[0] == "snapshot_download": - # derive comparable key shape: (type, repo_id, None, None, revision) - comparable = ("load_dataset", k[1], None, None, k[3]) - if comparable in load_keys: - continue - filtered.append(v) - return filtered - - -@contextmanager -def capture_hf_dataset_calls(): - captured: list[dict] = [] - _buffer_token = _CURRENT_CAPTURE_BUFFER.set(captured) - - import datasets as _ds # type: ignore - import huggingface_hub as _hfh # type: ignore - - _orig_load_dataset = _ds.load_dataset - _orig_snapshot_download = _hfh.snapshot_download - _orig_hf_hub_download = _hfh.hf_hub_download - - def _load_dataset_proxy(path, *args, **kwargs): # noqa: ANN001 - name = ( - kwargs.get("name") - if "name" in kwargs - else (args[0] if len(args) > 0 else None) - ) - data_files = ( - kwargs.get("data_files") - if "data_files" in kwargs - else (args[1] if len(args) > 1 else None) - ) - split = ( - kwargs.get("split") - if "split" in kwargs - else (args[2] if len(args) > 2 else None) - ) - trust_remote_code = 
kwargs.get("trust_remote_code") - revision = kwargs.get("revision") - buf = _CURRENT_CAPTURE_BUFFER.get() - if isinstance(buf, list): - buf.append( - { - "type": "load_dataset", - "path": path, - "name": name, - "data_files": data_files, - "split": split, - "revision": revision, - "trust_remote_code": trust_remote_code, - } - ) - return _orig_load_dataset(path, *args, **kwargs) - - def _snapshot_download_proxy(*args, **kwargs): # noqa: ANN001 - repo_id = ( - kwargs.get("repo_id") - if "repo_id" in kwargs - else (args[0] if len(args) > 0 else None) - ) - repo_type = ( - kwargs.get("repo_type") - if "repo_type" in kwargs - else (args[1] if len(args) > 1 else None) - ) - revision = ( - kwargs.get("revision") - if "revision" in kwargs - else (args[2] if len(args) > 2 else None) - ) - buf = _CURRENT_CAPTURE_BUFFER.get() - if isinstance(buf, list): - buf.append( - { - "type": "snapshot_download", - "repo_id": repo_id, - "repo_type": repo_type, - "revision": revision, - } - ) - return _orig_snapshot_download(*args, **kwargs) - - def _hf_hub_download_proxy(*args, **kwargs): # noqa: ANN001 - repo_id = ( - kwargs.get("repo_id") - if "repo_id" in kwargs - else (args[0] if len(args) > 0 else None) - ) - filename = ( - kwargs.get("filename") - if "filename" in kwargs - else (args[1] if len(args) > 1 else None) - ) - repo_type = ( - kwargs.get("repo_type") - if "repo_type" in kwargs - else (args[2] if len(args) > 2 else None) - ) - revision = ( - kwargs.get("revision") - if "revision" in kwargs - else (args[3] if len(args) > 3 else None) - ) - buf = _CURRENT_CAPTURE_BUFFER.get() - if isinstance(buf, list): - buf.append( - { - "type": "hf_hub_download", - "repo_id": repo_id, - "filename": filename, - "repo_type": repo_type, - "revision": revision, - } - ) - return _orig_hf_hub_download(*args, **kwargs) - - _ds.load_dataset = _load_dataset_proxy # type: ignore[assignment] - _hfh.snapshot_download = _snapshot_download_proxy # type: ignore[assignment] - _hfh.hf_hub_download = _hf_hub_download_proxy # type: ignore[assignment] - - try: - yield captured - finally: - _ds.load_dataset = _orig_load_dataset # type: ignore[assignment] - _hfh.snapshot_download = _orig_snapshot_download # type: ignore[assignment] - _hfh.hf_hub_download = _orig_hf_hub_download # type: ignore[assignment] - _CURRENT_CAPTURE_BUFFER.reset(_buffer_token) - - -def prewarm_from_payload(payload: dict | None, *, trust_remote_code: bool = True) -> None: - if not isinstance(payload, dict): - return - calls = payload.get("calls") - if not isinstance(calls, list): - return - - from datasets import load_dataset # type: ignore - from huggingface_hub import hf_hub_download, snapshot_download # type: ignore - - for call in calls: - if not isinstance(call, dict): - continue - # Unified prewarm log message - if call.get("type") == "load_dataset": - path = call.get("path") - name = call.get("name") - else: - repo_id = call.get("repo_id") - filename = call.get("filename") - - if call.get("type") == "snapshot_download": - repo_id = call.get("repo_id") - if isinstance(repo_id, str) and repo_id: - snapshot_download( - repo_id=repo_id, - repo_type=call.get("repo_type") or "dataset", - revision=call.get("revision"), - ) - continue - if call.get("type") == "hf_hub_download": - repo_id = call.get("repo_id") - filename = call.get("filename") - if isinstance(repo_id, str) and isinstance(filename, str): - hf_hub_download( - repo_id=repo_id, - filename=filename, - repo_type=call.get("repo_type"), - revision=call.get("revision"), - ) - continue - path = 
call.get("path") - name = call.get("name") - data_files = call.get("data_files") - split = call.get("split") - revision = call.get("revision") - trc = call.get("trust_remote_code", trust_remote_code) - kwargs: dict = {} - if name is not None: - kwargs["name"] = name - if data_files is not None: - kwargs["data_files"] = data_files - if revision is not None: - kwargs["revision"] = revision - kwargs["trust_remote_code"] = bool(trc) - if split is not None: - load_dataset(path, split=split, **kwargs) - else: - load_dataset(path, **kwargs) diff --git a/oellm/task_groups.py b/oellm/task_groups.py index 73c7d35..b99f083 100644 --- a/oellm/task_groups.py +++ b/oellm/task_groups.py @@ -5,10 +5,18 @@ import yaml +@dataclass +class DatasetSpec: + repo_id: str + subset: str | None = None + + @dataclass class _Task: name: str n_shots: list[int] | None = None + dataset: str | None = None + subset: str | None = None @dataclass @@ -18,6 +26,7 @@ class TaskGroup: suite: str description: str n_shots: list[int] | None = None + dataset: str | None = None def __post_init__(self): for task in self.tasks: @@ -27,6 +36,8 @@ def __post_init__(self): raise ValueError( f"N_shots is not set for task {task.name} and no default n_shots is set for the task group: {self.name}" ) + if task.dataset is None and self.dataset is not None: + task.dataset = self.dataset @classmethod def from_dict(cls, name: str, data: dict) -> "TaskGroup": @@ -34,7 +45,16 @@ def from_dict(cls, name: str, data: dict) -> "TaskGroup": for task_data in data["tasks"]: task_name = task_data["task"] task_n_shots = task_data.get("n_shots") - tasks.append(_Task(name=task_name, n_shots=task_n_shots)) + task_dataset = task_data.get("dataset") + task_subset = task_data.get("subset") + tasks.append( + _Task( + name=task_name, + n_shots=task_n_shots, + dataset=task_dataset, + subset=task_subset, + ) + ) return cls( name=name, @@ -42,6 +62,7 @@ def from_dict(cls, name: str, data: dict) -> "TaskGroup": suite=data["suite"], description=data["description"], n_shots=data.get("n_shots"), + dataset=data.get("dataset"), ) @@ -140,3 +161,70 @@ def _expand_task_groups(group_names: Iterable[str]) -> list[TaskGroupResult]: ) return results + + +def _collect_dataset_specs(group_names: Iterable[str]) -> list[DatasetSpec]: + parsed = _parse_task_groups([str(n).strip() for n in group_names if str(n).strip()]) + + specs: list[DatasetSpec] = [] + seen: set[tuple[str, str | None]] = set() + + def add_spec(dataset: str | None, subset: str | None): + if dataset is None: + return + key = (dataset, subset) + if key not in seen: + seen.add(key) + specs.append(DatasetSpec(repo_id=dataset, subset=subset)) + + for _, group in parsed.items(): + if isinstance(group, TaskGroup): + for t in group.tasks: + add_spec(t.dataset, t.subset) + else: + for g in group.task_groups: + for t in g.tasks: + add_spec(t.dataset, t.subset) + + return specs + + +def _build_task_dataset_map() -> dict[str, DatasetSpec]: + """Build a mapping from task names to their dataset specs from all task groups.""" + data = ( + yaml.safe_load((files("oellm.resources") / "task-groups.yaml").read_text()) or {} + ) + + all_group_names = list(data.get("task_groups", {}).keys()) + parsed = _parse_task_groups(all_group_names) + + task_map: dict[str, DatasetSpec] = {} + + for _, group in parsed.items(): + if isinstance(group, TaskGroup): + for t in group.tasks: + if t.dataset and t.name not in task_map: + task_map[t.name] = DatasetSpec(repo_id=t.dataset, subset=t.subset) + + return task_map + + +def 
_lookup_dataset_specs_for_tasks(task_names: Iterable[str]) -> list[DatasetSpec]: + """Look up dataset specs for individual task names from the task groups registry.""" + task_map = _build_task_dataset_map() + + specs: list[DatasetSpec] = [] + seen: set[tuple[str, str | None]] = set() + + for task_name in task_names: + task_name = str(task_name).strip() + if not task_name: + continue + spec = task_map.get(task_name) + if spec: + key = (spec.repo_id, spec.subset) + if key not in seen: + seen.add(key) + specs.append(spec) + + return specs diff --git a/oellm/utils.py b/oellm/utils.py index ef2ed5f..7e1fa01 100644 --- a/oellm/utils.py +++ b/oellm/utils.py @@ -15,16 +15,6 @@ from rich.console import Console from rich.logging import RichHandler -from oellm.task_cache import ( - capture_hf_dataset_calls, - dedupe_calls, - prewarm_from_payload, - task_cache_get_payload, - task_cache_lookup, - task_cache_mark_resolved, - task_cache_set_payload, -) - _RICH_CONSOLE: Console | None = None @@ -263,158 +253,32 @@ def _process_model_paths(models: Iterable[str]): ) -def _pre_download_task_datasets( - tasks: Iterable[str], trust_remote_code: bool = True +def _pre_download_datasets_from_specs( + specs: Iterable, trust_remote_code: bool = True ) -> None: - processed: set[str] = set() - - misses: list[str] = [] - console = get_console() - with console.status("Checking lm-eval datasets…", spinner="dots") as status: - cache_hits = 0 - for task_name in tasks: - if not isinstance(task_name, str) or task_name in processed: - continue - processed.add(task_name) - if task_cache_lookup("lm-eval", task_name): - cache_hits += 1 - status.update( - f"Checking lm-eval datasets… {cache_hits} cached, {len(misses)} to prepare" - ) - continue - misses.append(task_name) - status.update( - f"Checking lm-eval datasets… {cache_hits} cached, {len(misses)} to prepare" - ) + from datasets import load_dataset - if not misses: - with console.status( - f"Using cached lm-eval datasets for {len(processed)} tasks…", - spinner="dots", - ) as status: - for task_name in processed: - if task_cache_lookup("lm-eval", task_name): - status.update(f"Loading cached dataset for '{task_name}'…") - prewarm_from_payload( - task_cache_get_payload("lm-eval", task_name), - trust_remote_code=trust_remote_code, - ) + specs_list = list(specs) + if not specs_list: return - from datasets import DownloadMode # type: ignore - from lm_eval.tasks import TaskManager # type: ignore - - tm = TaskManager() - - with console.status( - f"Preparing lm-eval datasets… {len(misses)} remaining", - spinner="dots", - ) as status: - for idx, task_name in enumerate(misses, 1): - status.update(f"Preparing dataset for '{task_name}' ({idx}/{len(misses)})") - - task_config = { - "task": task_name, - "dataset_kwargs": {"trust_remote_code": trust_remote_code}, - } - - with capture_hf_dataset_calls() as captured_calls: - task_objects = tm.load_config(task_config) - - stack = [task_objects] - while stack: - current = stack.pop() - if isinstance(current, dict): - stack.extend(current.values()) - continue - if hasattr(current, "download") and callable(current.download): - try: - current.download( - download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS - ) # type: ignore[arg-type] - except TypeError as e: - logging.error( - f"Failed to download dataset for task '{task_name}' with download_mode=REUSE_DATASET_IF_EXISTS: {e}" - ) - current.download() # type: ignore[misc] - - if captured_calls: - payload = {"calls": dedupe_calls(captured_calls)} - task_cache_set_payload("lm-eval", task_name, 
payload) - task_cache_mark_resolved("lm-eval", task_name) - logging.debug(f"Finished dataset preparation for task '{task_name}'.") - - -def _pre_download_lighteval_datasets(tasks: Iterable[str]) -> None: - seen: set[str] = set() - misses: list[str] = [] - tasks = [str(task).strip() for task in tasks] console = get_console() - with console.status("Checking lighteval datasets…", spinner="dots") as status: - cache_hits = 0 - for task in tasks: - if not task or task in seen: - continue - seen.add(task) - if task_cache_lookup("lighteval", task): - cache_hits += 1 - status.update( - f"Checking lighteval datasets… {cache_hits} cached, {len(misses)} to prepare" - ) - continue - misses.append(task) - status.update( - f"Checking lighteval datasets… {cache_hits} cached, {len(misses)} to prepare" - ) - - if not misses: - with console.status( - f"Using cached lighteval datasets for {len(seen)} tasks…", - spinner="dots", - ): - for task in seen: - if task_cache_lookup("lighteval", task): - prewarm_from_payload( - task_cache_get_payload("lighteval", task), - trust_remote_code=True, - ) - return with console.status( - f"Preparing lighteval datasets… {len(misses)} remaining", + f"Downloading datasets… {len(specs_list)} datasets", spinner="dots", ) as status: - for idx, task in enumerate(misses, 1): - status.update(f"Preparing dataset for '{task}' ({idx}/{len(misses)})") - with capture_hf_dataset_calls() as captured_calls: - from lighteval.tasks.lighteval_task import LightevalTask - from lighteval.tasks.registry import ( - TRUNCATE_FEW_SHOTS_DEFAULTS, - Registry, - ) - - reg = Registry(custom_tasks="lighteval.tasks.multilingual.tasks") - truncate_default = int(TRUNCATE_FEW_SHOTS_DEFAULTS) - - spec = task - if "|" not in spec: - spec = f"lighteval|{spec}|0|{truncate_default}" - elif spec.count("|") == 1: - spec = f"{spec}|0|{truncate_default}" - elif spec.count("|") == 2: - spec = f"{spec}|{truncate_default}" - - configs = reg.get_tasks_configs(spec) - task_dict = reg.get_tasks_from_configs(configs) - LightevalTask.load_datasets(task_dict) - - payload = ( - {"calls": dedupe_calls(captured_calls)} - if captured_calls - else {"calls": []} + for idx, spec in enumerate(specs_list, 1): + label = f"{spec.repo_id}" + (f"/{spec.subset}" if spec.subset else "") + status.update(f"Downloading '{label}' ({idx}/{len(specs_list)})") + + load_dataset( + spec.repo_id, + name=spec.subset, + trust_remote_code=trust_remote_code, ) - task_cache_set_payload("lighteval", task, payload) - task_cache_mark_resolved("lighteval", task) + + logging.debug(f"Finished downloading dataset '{label}'.") @contextmanager @@ -483,7 +347,7 @@ def filtered_module_debug(msg, *args, **kwargs): return orig_module_debug(msg, *args, **kwargs) return None - builtins.print = filtered_print + builtins.print = filtered_print # type: ignore logging.Logger.info = filtered_logger_info # type: ignore[assignment] logging.Logger.debug = filtered_logger_debug # type: ignore[assignment] logging.info = filtered_module_info # type: ignore[assignment] diff --git a/pyproject.toml b/pyproject.toml index 42855e0..6a26ece 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,10 +9,6 @@ dependencies = [ "jsonargparse", "datasets", "rich", - "torch", - "lm-eval", - "lighteval[extended_tasks,multilingual] @ git+https://github.com/huggingface/lighteval.git@63424f4e795ecc577b90646381b374af3a627978", - "pydantic<2.12", "huggingface_hub", "pyyaml", ] @@ -35,19 +31,6 @@ module-name = "oellm" module-root = "" include = ["oellm/resources/*"] -[tool.uv.sources] -torch = [ - { 
index = "pytorch-cpu" }, -] -torchvision = [ - { index = "pytorch-cpu" }, -] - -[[tool.uv.index]] -name = "pytorch-cpu" -url = "https://download.pytorch.org/whl/cpu" -explicit = true - [tool.ruff] line-length = 90 target-version = "py312" diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 0000000..7248724 --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,113 @@ +import sys +from importlib.resources import files + +import pytest +import yaml +from datasets import get_dataset_config_names +from huggingface_hub import dataset_info + +from oellm.task_groups import DatasetSpec, TaskGroup, _parse_task_groups + + +def collect_all_dataset_specs() -> list[DatasetSpec]: + """Collect all unique dataset specs from all task groups.""" + data = ( + yaml.safe_load((files("oellm.resources") / "task-groups.yaml").read_text()) or {} + ) + + all_group_names = list(data.get("task_groups", {}).keys()) + list( + data.get("super_groups", {}).keys() + ) + parsed = _parse_task_groups(all_group_names) + + specs: list[DatasetSpec] = [] + seen: set[tuple[str, str | None]] = set() + + def add_spec(dataset: str | None, subset: str | None): + if dataset is None: + return + key = (dataset, subset) + if key not in seen: + seen.add(key) + specs.append(DatasetSpec(repo_id=dataset, subset=subset)) + + for _, group in parsed.items(): + if isinstance(group, TaskGroup): + for t in group.tasks: + add_spec(t.dataset, t.subset) + else: + for g in group.task_groups: + for t in g.tasks: + add_spec(t.dataset, t.subset) + + return specs + + +def check_dataset_exists(spec: DatasetSpec) -> tuple[bool, str]: + """Check if a dataset exists on HuggingFace.""" + label = f"{spec.repo_id}" + (f"/{spec.subset}" if spec.subset else "") + + info = dataset_info(spec.repo_id) + if info is None: + return False, f"Dataset repo '{spec.repo_id}' not found on HuggingFace" + + if spec.subset: + configs = get_dataset_config_names(spec.repo_id, trust_remote_code=True) + if spec.subset not in configs: + return ( + False, + f"Subset '{spec.subset}' not found in {spec.repo_id}. Available: {configs[:10]}{'...' if len(configs) > 10 else ''}", + ) + + return True, f"OK: {label}" + + +ALL_SPECS = collect_all_dataset_specs() + + +@pytest.mark.parametrize( + "spec", + ALL_SPECS, + ids=[f"{s.repo_id}/{s.subset}" if s.subset else s.repo_id for s in ALL_SPECS], +) +def test_dataset_exists(spec: DatasetSpec): + """Test that each dataset specified in task-groups.yaml exists on HuggingFace.""" + success, message = check_dataset_exists(spec) + assert success, message + + +def main(): + print("Collecting dataset specs from task-groups.yaml...") + specs = collect_all_dataset_specs() + print(f"Found {len(specs)} unique dataset specs\n") + + failed = [] + passed = [] + + for spec in specs: + label = f"{spec.repo_id}" + (f"/{spec.subset}" if spec.subset else "") + print(f"Checking {label}... ", end="", flush=True) + + success, message = check_dataset_exists(spec) + if success: + print("✓") + passed.append(spec) + else: + print(f"✗ - {message}") + failed.append((spec, message)) + + print(f"\n{'=' * 60}") + print(f"Results: {len(passed)} passed, {len(failed)} failed") + + if failed: + print("\nFailed datasets:") + for spec, msg in failed: + print(f" - {spec.repo_id}/{spec.subset}: {msg}") + sys.exit(1) + else: + print("\nAll datasets verified successfully!") + sys.exit(0) + + +if __name__ == "__main__": + main()