From 302a08a521c71ab04144225a667666b4292bca87 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 12 Aug 2025 12:00:46 -0500 Subject: [PATCH 01/16] mostly working new tests draft --- test_new/__init__.py | 72 +++++ test_new/conftest.py | 135 ++++++++ test_new/framework.py | 166 ++++++++++ test_new/opt_adam.py | 25 ++ test_new/opt_adamw.py | 26 ++ test_new/opt_adan.py | 74 +++++ test_new/opt_anyadam.py | 109 +++++++ test_new/opt_lion.py | 50 +++ test_new/opt_radam.py | 25 ++ test_new/opt_ranger.py | 34 ++ test_new/opt_sgd.py | 75 +++++ test_new/opt_stableadamw.py | 25 ++ test_new/optimizer_tests.py | 282 +++++++++++++++++ test_new/pytest_integration.py | 155 +++++++++ test_new/test_optimizers.py | 556 +++++++++++++++++++++++++++++++++ 15 files changed, 1809 insertions(+) create mode 100644 test_new/__init__.py create mode 100644 test_new/conftest.py create mode 100644 test_new/framework.py create mode 100644 test_new/opt_adam.py create mode 100644 test_new/opt_adamw.py create mode 100644 test_new/opt_adan.py create mode 100644 test_new/opt_anyadam.py create mode 100644 test_new/opt_lion.py create mode 100644 test_new/opt_radam.py create mode 100644 test_new/opt_ranger.py create mode 100644 test_new/opt_sgd.py create mode 100644 test_new/opt_stableadamw.py create mode 100644 test_new/optimizer_tests.py create mode 100644 test_new/pytest_integration.py create mode 100644 test_new/test_optimizers.py diff --git a/test_new/__init__.py b/test_new/__init__.py new file mode 100644 index 0000000..b4e3a8f --- /dev/null +++ b/test_new/__init__.py @@ -0,0 +1,72 @@ +""" +Unified optimizer test framework. + +This module provides a simplified, dataclass-based approach to optimizer testing +that replaces the complex OptimizerSpec/OptimizerVariant architecture. +""" + +__version__ = "1.0.0" + +# Core framework components +from .framework import BaseParams, OptimizerTest, ToleranceConfig + +# Test discovery and registry +from .optimizer_tests import ( + ALL_OPTIMIZER_TESTS, + auto_generate_variants, + discover_optimizer_tests, + get_all_optimizer_names, + get_all_variant_names, + get_test_by_name, + get_test_count, + get_tests_by_optimizer, + get_tests_by_variant, + print_test_summary, +) + +# Pytest integration +from .pytest_integration import ( + create_marked_backends, + create_marked_device_types, + create_marked_dtypes, + create_marked_optimizer_tests, + create_float32_only_dtypes, + create_gpu_only_device_types, + create_test_matrix, + get_backend_marks, + get_device_marks, + get_dtype_marks, + get_optimizer_marks, + print_mark_summary, +) + +__all__ = [ + # Core framework + "BaseParams", + "OptimizerTest", + "ToleranceConfig", + # Test discovery + "ALL_OPTIMIZER_TESTS", + "auto_generate_variants", + "discover_optimizer_tests", + "get_all_optimizer_names", + "get_all_variant_names", + "get_test_by_name", + "get_test_count", + "get_tests_by_optimizer", + "get_tests_by_variant", + "print_test_summary", + # Pytest integration + "create_marked_backends", + "create_marked_device_types", + "create_marked_dtypes", + "create_marked_optimizer_tests", + "create_float32_only_dtypes", + "create_gpu_only_device_types", + "create_test_matrix", + "get_backend_marks", + "get_device_marks", + "get_dtype_marks", + "get_optimizer_marks", + "print_mark_summary", +] diff --git a/test_new/conftest.py b/test_new/conftest.py new file mode 100644 index 0000000..5fb7f39 --- /dev/null +++ b/test_new/conftest.py @@ -0,0 +1,135 @@ +"""Pytest configuration and fixtures for the unified optimizer test framework. 
+ +This module provides pytest configuration, custom mark registration, and fixtures +for running optimizer tests across different devices, dtypes, and backends. +""" + +import pytest +import torch +from packaging import version + +from .optimizer_tests import get_all_optimizer_names + + +def pytest_configure(config): + """Configure pytest with custom marks for optimizer testing.""" + + # Register device marks + config.addinivalue_line("markers", "cpu: mark test to run on CPU") + config.addinivalue_line("markers", "gpu: mark test to run on GPU") + + # Register dtype marks + config.addinivalue_line("markers", "float32: mark test to run with float32 dtype") + config.addinivalue_line("markers", "bfloat16: mark test to run with bfloat16 dtype") + + # Register backend marks + config.addinivalue_line("markers", "torch: mark test to run with torch backend") + config.addinivalue_line("markers", "triton: mark test to run with triton backend") + + optimizer_names = get_all_optimizer_names() + for optimizer_name in optimizer_names: + config.addinivalue_line("markers", f"{optimizer_name}: mark test for {optimizer_name} optimizer") + + +# Check for minimum PyTorch version for Triton support +MIN_TORCH_2_6 = version.parse("2.6.0") +CURRENT_TORCH_VERSION = version.parse(torch.__version__.split("+")[0]) # Remove any +cu118 suffix +HAS_TRITON_SUPPORT = CURRENT_TORCH_VERSION >= MIN_TORCH_2_6 + + +@pytest.fixture(scope="session") +def gpu_device(): + """Provide GPU device for testing if available. + + Returns: + torch.device: GPU device (cuda, xpu, or mps) if available, otherwise None. + """ + if torch.cuda.is_available(): + return torch.device("cuda") + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + return torch.device("xpu") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") + else: + return None + + +@pytest.fixture(scope="session") +def has_gpu(gpu_device): + """Check if GPU is available for testing. + + Returns: + bool: True if GPU device is available, False otherwise. + """ + return gpu_device is not None + + +@pytest.fixture(scope="session") +def has_triton(): + """Check if Triton backend is available. + + Returns: + bool: True if Triton is supported (PyTorch >= 2.6), False otherwise. + """ + return HAS_TRITON_SUPPORT + + +@pytest.fixture +def tolerance_config(): + """Provide default tolerance configuration for numerical comparisons. + + Returns: + dict: Default tolerance settings for different dtypes. + """ + from .framework import ToleranceConfig + + return { + torch.float32: ToleranceConfig(rtol=1e-5, atol=1e-8), + torch.bfloat16: ToleranceConfig(rtol=1e-3, atol=1e-5), # More relaxed for bfloat16 + } + + +@pytest.fixture +def cpu_device(): + """Provide CPU device for testing. + + Returns: + torch.device: CPU device. 
+ """ + return torch.device("cpu") + + +def pytest_collection_modifyitems(config, items): + """Modify test collection to add automatic skipping for unavailable resources.""" + + # Skip GPU tests if no GPU is available + if not torch.cuda.is_available() and not (hasattr(torch, "xpu") and torch.xpu.is_available()): + skip_gpu = pytest.mark.skip(reason="GPU not available") + for item in items: + if "gpu" in item.keywords: + item.add_marker(skip_gpu) + + # Skip Triton tests if not supported + if not HAS_TRITON_SUPPORT: + skip_triton = pytest.mark.skip(reason=f"Triton requires PyTorch >= {MIN_TORCH_2_6}, got {CURRENT_TORCH_VERSION}") + for item in items: + if "triton" in item.keywords: + item.add_marker(skip_triton) + + +def pytest_runtest_setup(item): + """Setup hook to perform additional test skipping based on marks.""" + + # Skip GPU tests on CPU-only systems + if "gpu" in item.keywords: + gpu_available = ( + torch.cuda.is_available() + or (hasattr(torch, "xpu") and torch.xpu.is_available()) + or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) + ) + if not gpu_available: + pytest.skip("GPU not available") + + # Skip Triton tests if not supported + if "triton" in item.keywords and not HAS_TRITON_SUPPORT: + pytest.skip(f"Triton requires PyTorch >= {MIN_TORCH_2_6}, got {CURRENT_TORCH_VERSION}") diff --git a/test_new/framework.py b/test_new/framework.py new file mode 100644 index 0000000..bb71b62 --- /dev/null +++ b/test_new/framework.py @@ -0,0 +1,166 @@ +"""Core framework components for the unified optimizer test system. + +This module provides the foundational dataclasses and utilities for defining +and executing optimizer tests in a type-safe, self-contained manner. +""" + +import inspect +from copy import deepcopy +from dataclasses import asdict, dataclass, field +from typing import Any + +import torch +from optimi.optimizer import OptimiOptimizer +from torch.optim.optimizer import Optimizer + + +@dataclass +class ToleranceConfig: + """Tolerance configuration for numerical comparisons.""" + + atol: float = 1e-6 + rtol: float = 1e-5 + max_error_rate: float = 0.0005 + equal_nan: bool = False + + +@dataclass +class BaseParams: + """Base class for all optimizer parameters with common fields.""" + + lr: float = 1e-3 + weight_decay: float = 0.0 + decouple_wd: bool = False + decouple_lr: bool = False + triton: bool = False + + def _filter_kwargs_for_class(self, optimizer_class: type) -> dict[str, Any]: + """Filter parameters based on optimizer signature inspection.""" + if optimizer_class is None: + return {} + + # Get the optimizer's __init__ signature + sig = inspect.signature(optimizer_class.__init__) + valid_params = set(sig.parameters.keys()) - {"self"} + + # Filter our parameters to only include those accepted by the optimizer + return {k: v for k, v in asdict(self).items() if k in valid_params} + + def to_optimi_kwargs(self, optimi_class: type) -> dict[str, Any]: + """Convert to kwargs for optimi optimizer.""" + return self._filter_kwargs_for_class(optimi_class) + + def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: + """Convert to kwargs for reference optimizer.""" + return self._filter_kwargs_for_class(reference_class) + + +@dataclass +class OptimizerTest: + """Complete self-contained optimizer test case.""" + + # Test identification + name: str # "adam_base", "sgd_momentum", etc. 
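+    # The prefix before the first underscore is the optimizer name (also used as the
+    # pytest mark) and the remainder is the variant; see the optimizer_name and
+    # variant_name properties below.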
+ + # Optimizer classes and parameters + optimi_class: type[OptimiOptimizer] + optimi_params: BaseParams + reference_class: type[Optimizer] + reference_params: BaseParams | None = None + + # Optional fully decoupled reference + fully_decoupled_reference: Optimizer | None = None + + # Test behavior overrides (optional) + test_decoupled_wd: bool = True + skip_tests: list[str] = field(default_factory=list) + any_precision: bool = False + custom_iterations: dict[str, int] | None = None + custom_tolerances: dict[torch.dtype, ToleranceConfig] | None = None + # Optional constraints + only_dtypes: list[torch.dtype] | None = None + + def __post_init__(self): + """Post-initialization checks and adjustments.""" + if self.reference_params is None: + self.reference_params = deepcopy(self.optimi_params) + + if self.custom_tolerances is None: + self.custom_tolerances = {} + if self.custom_tolerances.get(torch.float32, None) is None: + self.custom_tolerances[torch.float32] = ToleranceConfig() + if self.custom_tolerances.get(torch.bfloat16, None) is None: + self.custom_tolerances[torch.bfloat16] = ToleranceConfig(atol=1e-3, rtol=1e-2, max_error_rate=0.01) + if self.custom_tolerances.get(torch.float16, None) is None: + self.custom_tolerances[torch.float16] = ToleranceConfig(atol=1e-4, rtol=1e-3, max_error_rate=0.01) + + @property + def optimizer_name(self) -> str: + """Extract optimizer name from test name (e.g., 'adam' from 'adam_base').""" + return self.name.split("_")[0] if "_" in self.name else self.name + + @property + def variant_name(self) -> str: + """Extract variant name from test name (e.g., 'base' from 'adam_base').""" + parts = self.name.split("_", 1) + return parts[1] if len(parts) > 1 else "base" + + def to_optimi_kwargs(self) -> dict[str, Any]: + """Get kwargs for optimi optimizer.""" + return self.optimi_params.to_optimi_kwargs(self.optimi_class) + + def to_reference_kwargs(self) -> dict[str, Any]: + """Get kwargs for reference optimizer.""" + return self.reference_params.to_reference_kwargs(self.reference_class) + + def should_skip_test(self, test_type: str) -> bool: + """Check if a specific test type should be skipped.""" + return test_type in self.skip_tests + + def get_tolerance(self, dtype: torch.dtype) -> ToleranceConfig: + """Get tolerance configuration for specific dtype.""" + return self.custom_tolerances[dtype] + + # Backwards-compatible alias to support existing call sites + def get_tolerance_for_dtype(self, dtype: torch.dtype) -> ToleranceConfig: + """Backward-compatible alias for get_tolerance.""" + return self.get_tolerance(dtype) + + def get_iterations_for_test(self, test_type: str) -> int: + """Get number of iterations for specific test type.""" + if self.custom_iterations and test_type in self.custom_iterations: + return self.custom_iterations[test_type] + + # Default iterations based on test type + defaults = {"correctness": 10, "gradient_release": 5, "accumulation": 5} + return defaults.get(test_type, 10) + + def supports_l2_weight_decay(self) -> bool: + """Check if optimizer supports L2 weight decay.""" + # all optimi optimizers which support l2 weight decay have a decouple_wd parameter + return "decouple_wd" in inspect.signature(self.optimi_class.__init__).parameters + + +def assert_most_approx_close( + a: torch.Tensor, + b: torch.Tensor, + rtol: float = 1e-3, + atol: float = 1e-3, + max_error_count: int = 0, + max_error_rate: float | None = None, + name: str = "", +) -> None: + """Assert that most values in two tensors are approximately close. 
+ + Allows for a small number of errors based on max_error_count and max_error_rate. + """ + idx = torch.isclose(a.float(), b.float(), rtol=rtol, atol=atol) + error_count = (idx == 0).sum().item() + + if max_error_rate is not None: + if error_count > (a.numel()) * max_error_rate and error_count > max_error_count: + print(f"{name}Too many values not close: assert {error_count} < {(a.numel()) * max_error_rate}") + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) + elif error_count > max_error_count: + print(f"{name}Too many values not close: assert {error_count} < {max_error_count}") + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) diff --git a/test_new/opt_adam.py b/test_new/opt_adam.py new file mode 100644 index 0000000..a1e0fb3 --- /dev/null +++ b/test_new/opt_adam.py @@ -0,0 +1,25 @@ +"""Adam optimizer test definitions.""" + +from dataclasses import dataclass + +import optimi +import torch + +from .framework import BaseParams, OptimizerTest + + +@dataclass +class AdamParams(BaseParams): + """Type-safe Adam optimizer parameters.""" + + betas: tuple[float, float] = (0.9, 0.99) + eps: float = 1e-6 + + +BASE_TEST = OptimizerTest( + name="adam_base", + optimi_class=optimi.Adam, + optimi_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + reference_class=torch.optim.Adam, + test_decoupled_wd=False, +) diff --git a/test_new/opt_adamw.py b/test_new/opt_adamw.py new file mode 100644 index 0000000..6845617 --- /dev/null +++ b/test_new/opt_adamw.py @@ -0,0 +1,26 @@ +"""AdamW optimizer test definitions.""" + +from dataclasses import dataclass + +import optimi +import torch +from tests import reference + +from .framework import BaseParams, OptimizerTest + + +@dataclass +class AdamWParams(BaseParams): + """Type-safe AdamW optimizer parameters.""" + + betas: tuple[float, float] = (0.9, 0.99) + eps: float = 1e-8 + + +BASE_TEST = OptimizerTest( + name="adamw_base", + optimi_class=optimi.AdamW, + optimi_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + reference_class=torch.optim.AdamW, + fully_decoupled_reference=reference.DecoupledAdamW, +) diff --git a/test_new/opt_adan.py b/test_new/opt_adan.py new file mode 100644 index 0000000..beaba77 --- /dev/null +++ b/test_new/opt_adan.py @@ -0,0 +1,74 @@ +"""Adan optimizer test definitions.""" + +from dataclasses import dataclass +from typing import Any + +import optimi +from tests import reference + +from .framework import BaseParams, OptimizerTest + + +@dataclass +class AdanParams(BaseParams): + """Type-safe Adan optimizer parameters.""" + + betas: tuple[float, float, float] = (0.98, 0.92, 0.99) + eps: float = 1e-8 + weight_decouple: bool = False # For adam_wd variant (maps to no_prox in reference) + adam_wd: bool = False # For optimi optimizer + + def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: + """Adan needs special parameter conversion for no_prox.""" + kwargs = super().to_reference_kwargs(reference_class) + + # Convert weight_decouple to no_prox for reference optimizer + if "weight_decouple" in kwargs: + kwargs["no_prox"] = kwargs.pop("weight_decouple") + + # Remove adam_wd as it's not used by reference + kwargs.pop("adam_wd", None) + + return kwargs + + +# Define all Adan test variants explicitly to match original tests +ALL_TESTS = [ + OptimizerTest( + name="adan_base", + optimi_class=optimi.Adan, + optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=0), + reference_class=reference.Adan, + 
reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6), + custom_iterations={"correctness": 20}, # Adan bfloat16 updates are noisier, use fewer iterations for GPU + ), + OptimizerTest( + name="adan_weight_decay", + optimi_class=optimi.Adan, + optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), + reference_class=reference.Adan, + reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), + custom_iterations={"correctness": 20}, + ), + OptimizerTest( + name="adan_adam_wd", + optimi_class=optimi.Adan, + optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, adam_wd=True), + reference_class=reference.Adan, + reference_params=AdanParams( + lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, weight_decouple=True + ), # no_prox=True in reference + custom_iterations={"correctness": 20}, + ), + OptimizerTest( + name="adan_decoupled_lr", + optimi_class=optimi.Adan, + optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-5, decouple_lr=True), + reference_class=reference.Adan, + reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), + custom_iterations={"correctness": 20}, + ), +] + +# Set BASE_TEST for auto-generation compatibility +BASE_TEST = ALL_TESTS[0] diff --git a/test_new/opt_anyadam.py b/test_new/opt_anyadam.py new file mode 100644 index 0000000..721f77a --- /dev/null +++ b/test_new/opt_anyadam.py @@ -0,0 +1,109 @@ +"""AnyAdam optimizer test definitions for Kahan summation precision tests.""" + +from dataclasses import dataclass + +import optimi +import torch +from tests.reference import AnyPrecisionAdamW + +from .framework import BaseParams, OptimizerTest, ToleranceConfig + + +@dataclass +class AnyAdamParams(BaseParams): + """Type-safe AnyAdam optimizer parameters with Kahan summation support.""" + + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + kahan_sum: bool = False + use_kahan_summation: bool = False + + def to_reference_kwargs(self, reference_class: type) -> dict: + """Convert parameters for AnyPrecisionAdamW reference.""" + kwargs = super().to_reference_kwargs(reference_class) + + # AnyPrecisionAdamW uses use_kahan_summation instead of kahan_sum + if "kahan_sum" in kwargs: + kwargs["use_kahan_summation"] = kwargs.pop("kahan_sum") + + # Set default precision dtypes for AnyPrecisionAdamW + if reference_class.__name__ == "AnyPrecisionAdamW": + kwargs.setdefault("momentum_dtype", torch.bfloat16) + kwargs.setdefault("variance_dtype", torch.bfloat16) + kwargs.setdefault("compensation_buffer_dtype", torch.bfloat16) + + return kwargs + + +ALL_TESTS = [ + OptimizerTest( + name="anyadam_kahan", + optimi_class=optimi.Adam, + optimi_params=AnyAdamParams( + lr=1e-3, + betas=(0.9, 0.99), + eps=1e-6, + weight_decay=0, + kahan_sum=True, + ), + reference_class=AnyPrecisionAdamW, + reference_params=AnyAdamParams( + lr=1e-3, + betas=(0.9, 0.99), + eps=1e-6, + weight_decay=0, + use_kahan_summation=True, + ), + only_dtypes=[torch.bfloat16], + any_precision=True, + custom_tolerances={torch.bfloat16: ToleranceConfig(rtol=2e-2, atol=2e-3, equal_nan=False)}, + ), + OptimizerTest( + name="anyadam_kahan_wd", + optimi_class=optimi.Adam, + optimi_params=AnyAdamParams( + lr=1e-3, + betas=(0.9, 0.99), + eps=1e-6, + weight_decay=0.01, + kahan_sum=True, + ), + reference_class=AnyPrecisionAdamW, + reference_params=AnyAdamParams( + lr=1e-3, + betas=(0.9, 0.99), + eps=1e-6, + weight_decay=0.01, + 
use_kahan_summation=True, + ), + only_dtypes=[torch.bfloat16], + any_precision=True, + custom_tolerances={torch.bfloat16: ToleranceConfig(rtol=5e-2, atol=1e-2, equal_nan=False)}, + ), + OptimizerTest( + name="anyadam_kahan_decoupled_lr", + optimi_class=optimi.Adam, + optimi_params=AnyAdamParams( + lr=1e-3, + betas=(0.9, 0.99), + eps=1e-6, + weight_decay=1e-5, + decouple_lr=True, + kahan_sum=True, + ), + reference_class=AnyPrecisionAdamW, + reference_params=AnyAdamParams( + lr=1e-3, + betas=(0.9, 0.99), + eps=1e-6, + weight_decay=1e-2, + use_kahan_summation=True, + ), + only_dtypes=[torch.bfloat16], + any_precision=True, + custom_tolerances={torch.bfloat16: ToleranceConfig(rtol=2e-2, atol=2e-3, equal_nan=False)}, + ), +] + +# For compatibility with auto-generation system +BASE_TEST = ALL_TESTS[0] diff --git a/test_new/opt_lion.py b/test_new/opt_lion.py new file mode 100644 index 0000000..9d7d0e5 --- /dev/null +++ b/test_new/opt_lion.py @@ -0,0 +1,50 @@ +"""Lion optimizer test definitions.""" + +from dataclasses import dataclass + +import optimi +from tests.reference import lion as reference_lion + +from .framework import BaseParams, OptimizerTest + + +@dataclass +class LionParams(BaseParams): + """Type-safe Lion optimizer parameters.""" + + betas: tuple[float, float] = (0.9, 0.99) + + +BASE_TEST = OptimizerTest( + name="lion_base", + optimi_class=optimi.Lion, + optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), + reference_class=reference_lion.Lion, + reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), +) + + +# Define all Adan test variants explicitly to match original tests +ALL_TESTS = [ + OptimizerTest( + name="lion_base", + optimi_class=optimi.Lion, + optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), + reference_class=reference_lion.Lion, + reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), + ), + OptimizerTest( + name="lion_decoupled_wd", + optimi_class=optimi.Lion, + optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1, decouple_wd=True), + reference_class=reference_lion.Lion, + reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1), + ), + OptimizerTest( + name="lion_decoupled_lr", + optimi_class=optimi.Lion, + optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-5, decouple_lr=True), + reference_class=reference_lion.Lion, + reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1), + ), +] diff --git a/test_new/opt_radam.py b/test_new/opt_radam.py new file mode 100644 index 0000000..0154cc8 --- /dev/null +++ b/test_new/opt_radam.py @@ -0,0 +1,25 @@ +"""RAdam optimizer test definitions.""" + +from dataclasses import dataclass + +import optimi +import torch + +from .framework import BaseParams, OptimizerTest + + +@dataclass +class RAdamParams(BaseParams): + """Type-safe RAdam optimizer parameters.""" + + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + + +BASE_TEST = OptimizerTest( + name="radam_base", + optimi_class=optimi.RAdam, + optimi_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + reference_class=torch.optim.RAdam, + reference_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), +) diff --git a/test_new/opt_ranger.py b/test_new/opt_ranger.py new file mode 100644 index 0000000..89a8521 --- /dev/null +++ b/test_new/opt_ranger.py @@ -0,0 +1,34 @@ +"""Ranger optimizer test definitions.""" + +from dataclasses import dataclass + +import optimi +from tests import reference + 
+from .framework import BaseParams, OptimizerTest + + +@dataclass +class RangerParams(BaseParams): + """Type-safe Ranger optimizer parameters.""" + + betas: tuple[float, float] = (0.9, 0.99) + eps: float = 1e-8 + k: int = 6 # Lookahead steps + alpha: float = 0.5 # Lookahead alpha + + +# Ranger only has base test - reference doesn't perform normal weight decay step +ALL_TESTS = [ + OptimizerTest( + name="ranger_base", + optimi_class=optimi.Ranger, + optimi_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), + reference_class=reference.Ranger, + reference_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), + custom_iterations={"gradient_release": 160}, # Ranger needs longer testing due to lookahead step + ) +] + +# Set BASE_TEST for auto-generation compatibility +BASE_TEST = ALL_TESTS[0] diff --git a/test_new/opt_sgd.py b/test_new/opt_sgd.py new file mode 100644 index 0000000..8845edd --- /dev/null +++ b/test_new/opt_sgd.py @@ -0,0 +1,75 @@ +"""SGD optimizer test definitions with custom parameter handling.""" + +from dataclasses import dataclass +from typing import Any + +import optimi +import torch +from tests import reference + +from .framework import BaseParams, OptimizerTest + + +@dataclass +class SGDParams(BaseParams): + """Type-safe SGD optimizer parameters.""" + + momentum: float = 0.0 + dampening: bool = False # Optimi uses bool instead of float + torch_init: bool = False + + def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: + """SGD needs special dampening conversion for reference optimizer.""" + kwargs = super().to_reference_kwargs(reference_class) + + # Convert dampening bool to float for reference optimizer + if "dampening" in kwargs and isinstance(kwargs["dampening"], bool): + kwargs["dampening"] = 0.9 if kwargs["dampening"] else 0.0 + + return kwargs + + +# Define all SGD test variants explicitly +ALL_TESTS = [ + OptimizerTest( + name="sgd_base", + optimi_class=optimi.SGD, + optimi_params=SGDParams(lr=1e-3, momentum=0, dampening=False, weight_decay=0), + reference_class=torch.optim.SGD, + reference_params=SGDParams(lr=1e-3, momentum=0, dampening=0, weight_decay=0), + skip_tests=["accumulation"], # SGD base skips accumulation tests + ), + OptimizerTest( + name="sgd_momentum", + optimi_class=optimi.SGD, + optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=0), + reference_class=torch.optim.SGD, + reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=0), + ), + OptimizerTest( + name="sgd_dampening", + optimi_class=optimi.SGD, + optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=True, weight_decay=0, torch_init=True), + reference_class=torch.optim.SGD, + reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=0), + ), + OptimizerTest( + name="sgd_weight_decay", + optimi_class=optimi.SGD, + optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=1e-2, decouple_wd=False), + reference_class=torch.optim.SGD, + reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=1e-2), + skip_tests=["accumulation"], # SGD with L2 weight decay skips accumulation tests + ), + OptimizerTest( + name="sgd_decoupled_lr", + optimi_class=optimi.SGD, + optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=True, decouple_lr=True, weight_decay=1e-5, torch_init=True), + reference_class=reference.DecoupledSGDW, + reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=1e-5), + 
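+        # reference.DecoupledSGDW serves as the fully decoupled weight decay
+        # reference for optimi's decouple_lr=True variant.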
custom_iterations={"accumulation": 20}, # SGD uses fewer iterations for accumulation + ), +] + +# Set BASE_TEST for auto-generation compatibility +BASE_TEST = ALL_TESTS[0] diff --git a/test_new/opt_stableadamw.py b/test_new/opt_stableadamw.py new file mode 100644 index 0000000..9a855e4 --- /dev/null +++ b/test_new/opt_stableadamw.py @@ -0,0 +1,25 @@ +"""StableAdamW optimizer test definitions.""" + +from dataclasses import dataclass + +import optimi +from tests import reference + +from .framework import BaseParams, OptimizerTest + + +@dataclass +class StableAdamWParams(BaseParams): + """Type-safe StableAdamW optimizer parameters.""" + + betas: tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + + +BASE_TEST = OptimizerTest( + name="stableadamw_base", + optimi_class=optimi.StableAdamW, + optimi_params=StableAdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + reference_class=reference.StableAdamWUnfused, + reference_params=StableAdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), +) diff --git a/test_new/optimizer_tests.py b/test_new/optimizer_tests.py new file mode 100644 index 0000000..6e139df --- /dev/null +++ b/test_new/optimizer_tests.py @@ -0,0 +1,282 @@ +"""Automatic test discovery and generation system for optimizer tests. + +This module provides functionality to automatically discover optimizer test definitions +from test files and generate torch test variants, creating a comprehensive test registry. +""" + +import importlib +import warnings +from pathlib import Path +from typing import Any + +from .framework import OptimizerTest + + +def discover_optimizer_tests() -> list[OptimizerTest]: + """Automatically discover and generate tests from all test modules. + + Scans the test_new directory for Python files containing optimizer test definitions. + Each module should define either BASE_TEST (for auto-generation) or ALL_TESTS (custom). + + Returns: + List of all discovered and generated OptimizerTest instances. 
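+
+    Example (illustrative):
+        tests = discover_optimizer_tests()
+        optimizer_names = sorted({t.optimizer_name for t in tests})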
+ """ + all_tests = [] + test_dir = Path(__file__).parent + test_files = [f for f in test_dir.glob("opt_*.py") if f.is_file()] + + for test_file in test_files: + module_name = test_file.stem + try: + # Import the test module using package-relative import + # __package__ will be 'optimi.test_new' when this file is imported as a package module + module = importlib.import_module(f".{module_name}", package=__package__) + + # Check for ALL_TESTS first (custom test definitions) + if hasattr(module, "ALL_TESTS"): + tests = getattr(module, "ALL_TESTS") + if isinstance(tests, list) and all(isinstance(t, OptimizerTest) for t in tests): + all_tests.extend(tests) + else: + warnings.warn(f"Module {module_name} has ALL_TESTS but it's not a list of OptimizerTest instances") + + # Check for BASE_TEST (for auto-generation) + elif hasattr(module, "BASE_TEST"): + base_test = getattr(module, "BASE_TEST") + if isinstance(base_test, OptimizerTest): + # Generate torch variants from base test + generated_tests = auto_generate_variants(base_test) + all_tests.extend(generated_tests) + else: + warnings.warn(f"Module {module_name} has BASE_TEST but it's not an OptimizerTest instance") + + else: + warnings.warn(f"Module {module_name} has neither BASE_TEST nor ALL_TESTS defined") + + except ImportError as e: + warnings.warn(f"Failed to import test module {module_name}: {e}") + except Exception as e: + warnings.warn(f"Error processing test module {module_name}: {e}") + + return all_tests + + +def auto_generate_variants(base_test: OptimizerTest) -> list[OptimizerTest]: + """Automatically generate torch test variants from a base test. + + Generates the following torch variants: + - base: Original test with no weight decay + - weight_decay: Test with weight decay enabled + - decoupled_wd: Test with decoupled weight decay + - decoupled_lr: Test with decoupled learning rate + + Args: + base_test: Base OptimizerTest to generate variants from. + + Returns: + List of generated test variants. 
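+
+    Note:
+        Generated names follow the pattern <optimizer>_base, <optimizer>_l2_wd,
+        <optimizer>_decoupled_wd, and <optimizer>_decoupled_lr; the weight decay
+        variants are only created when the base test supports them.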
+ """ + variants = [] + + # Base variant (ensure weight_decay is 0) + base_params = _copy_params_with_overrides(base_test.optimi_params, weight_decay=0.0) + base_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=0.0) + + base_variant = OptimizerTest( + name=f"{base_test.optimizer_name}_base", + optimi_class=base_test.optimi_class, + optimi_params=base_params, + reference_class=base_test.reference_class, + reference_params=base_ref_params, + skip_tests=base_test.skip_tests.copy(), + only_dtypes=base_test.only_dtypes, + any_precision=base_test.any_precision, + custom_iterations=base_test.custom_iterations, + custom_tolerances=base_test.custom_tolerances, + ) + variants.append(base_variant) + + # L2 weight decay variant + if base_test.supports_l2_weight_decay(): + l2_params = _copy_params_with_overrides(base_test.optimi_params, weight_decay=0.01) + l2_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=0.01) + + l2_variant = OptimizerTest( + name=f"{base_test.optimizer_name}_l2_wd", + optimi_class=base_test.optimi_class, + optimi_params=l2_params, + reference_class=base_test.reference_class, + reference_params=l2_ref_params, + skip_tests=base_test.skip_tests.copy(), + only_dtypes=base_test.only_dtypes, + any_precision=base_test.any_precision, + custom_iterations=base_test.custom_iterations, + custom_tolerances=base_test.custom_tolerances, + ) + variants.append(l2_variant) + + # Decoupled weight decay variant + if base_test.test_decoupled_wd: + decoupled_wd_params = _copy_params_with_overrides(base_test.optimi_params, weight_decay=0.01, decouple_wd=True) + decoupled_wd_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=0.01, decouple_wd=True) + + decoupled_wd_variant = OptimizerTest( + name=f"{base_test.optimizer_name}_decoupled_wd", + optimi_class=base_test.optimi_class, + optimi_params=decoupled_wd_params, + reference_class=base_test.reference_class, + reference_params=decoupled_wd_ref_params, + skip_tests=base_test.skip_tests.copy(), + only_dtypes=base_test.only_dtypes, + any_precision=base_test.any_precision, + custom_iterations=base_test.custom_iterations, + custom_tolerances=base_test.custom_tolerances, + ) + variants.append(decoupled_wd_variant) + + # Decoupled learning rate variant + decoupled_lr_params = _copy_params_with_overrides(base_test.optimi_params, weight_decay=1e-5, decouple_lr=True) + if base_test.fully_decoupled_reference is not None: + decoupled_lr_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=1e-5, decouple_lr=True) + reference_class = base_test.fully_decoupled_reference + else: + decoupled_lr_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=0.01, decouple_lr=True) + reference_class = base_test.reference_class + + decoupled_lr_variant = OptimizerTest( + name=f"{base_test.optimizer_name}_decoupled_lr", + optimi_class=base_test.optimi_class, + optimi_params=decoupled_lr_params, + reference_class=reference_class, + reference_params=decoupled_lr_ref_params, + skip_tests=base_test.skip_tests.copy(), + only_dtypes=base_test.only_dtypes, + any_precision=base_test.any_precision, + custom_iterations=base_test.custom_iterations, + custom_tolerances=base_test.custom_tolerances, + ) + variants.append(decoupled_lr_variant) + + return variants + + +def _copy_params_with_overrides(params: Any, **overrides: Any) -> Any: + """Create a copy of parameter dataclass with specified overrides. 
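+
+    Functionally equivalent to dataclasses.replace(params, **overrides) for these
+    plain parameter dataclasses.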
+ + Args: + params: Original parameter dataclass instance. + **overrides: Field values to override. + + Returns: + New parameter instance with overrides applied. + """ + # Get all current field values + current_values = {} + for field_info in params.__dataclass_fields__.values(): + current_values[field_info.name] = getattr(params, field_info.name) + + # Apply overrides + current_values.update(overrides) + + # Create new instance + return type(params)(**current_values) + + +# Central registry of all discovered tests +ALL_OPTIMIZER_TESTS: list[OptimizerTest] = [] + + +def _initialize_test_registry() -> None: + """Initialize the test registry by discovering all tests.""" + global ALL_OPTIMIZER_TESTS + if not ALL_OPTIMIZER_TESTS: # Only initialize once + ALL_OPTIMIZER_TESTS = discover_optimizer_tests() + + +def get_tests_by_optimizer(optimizer_name: str) -> list[OptimizerTest]: + """Get all tests for a specific optimizer. + + Args: + optimizer_name: Name of the optimizer (e.g., 'adam', 'sgd'). + + Returns: + List of OptimizerTest instances for the specified optimizer. + """ + _initialize_test_registry() + return [test for test in ALL_OPTIMIZER_TESTS if test.optimizer_name == optimizer_name] + + +def get_tests_by_variant(variant_name: str) -> list[OptimizerTest]: + """Get all tests for a specific variant across all optimizers. + + Args: + variant_name: Name of the variant (e.g., 'base', 'weight_decay'). + + Returns: + List of OptimizerTest instances for the specified variant. + """ + _initialize_test_registry() + return [test for test in ALL_OPTIMIZER_TESTS if test.variant_name == variant_name] + + +def get_test_by_name(name: str) -> OptimizerTest | None: + """Get a specific test by its full name. + + Args: + name: Full test name (e.g., 'adam_base', 'sgd_momentum'). + + Returns: + OptimizerTest instance if found, None otherwise. + """ + _initialize_test_registry() + for test in ALL_OPTIMIZER_TESTS: + if test.name == name: + return test + return None + + +def get_all_optimizer_names() -> list[str]: + """Get list of all available optimizer names. + + Returns: + Sorted list of unique optimizer names. + """ + _initialize_test_registry() + optimizer_names = {test.optimizer_name for test in ALL_OPTIMIZER_TESTS} + return sorted(optimizer_names) + + +def get_all_variant_names() -> list[str]: + """Get list of all available variant names. + + Returns: + Sorted list of unique variant names. + """ + _initialize_test_registry() + variant_names = {test.variant_name for test in ALL_OPTIMIZER_TESTS} + return sorted(variant_names) + + +def get_test_count() -> int: + """Get total number of discovered tests. + + Returns: + Total count of OptimizerTest instances. + """ + _initialize_test_registry() + return len(ALL_OPTIMIZER_TESTS) + + +def print_test_summary() -> None: + """Print a summary of discovered tests for debugging.""" + _initialize_test_registry() + + print(f"Discovered {len(ALL_OPTIMIZER_TESTS)} optimizer tests:") + print(f"Optimizers: {', '.join(get_all_optimizer_names())}") + print(f"Variants: {', '.join(get_all_variant_names())}") + + # Group by optimizer + for optimizer_name in get_all_optimizer_names(): + tests = get_tests_by_optimizer(optimizer_name) + test_names = [test.name for test in tests] + print(f" {optimizer_name}: {', '.join(test_names)}") diff --git a/test_new/pytest_integration.py b/test_new/pytest_integration.py new file mode 100644 index 0000000..2af5dda --- /dev/null +++ b/test_new/pytest_integration.py @@ -0,0 +1,155 @@ +"""Pytest integration functions for automatic mark generation. 
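+
+The marks used here (per-optimizer names plus cpu/gpu, float32/bfloat16, and
+torch/triton) are registered in conftest.pytest_configure.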
+ +This module provides functions to create pytest parameters with automatic marks +for optimizers, devices, dtypes, and backends, enabling flexible test execution. +""" + +import pytest +import torch + +from .optimizer_tests import ALL_OPTIMIZER_TESTS, _initialize_test_registry + +_CACHED_TESTS: list | None = None + + +def create_marked_optimizer_tests(): + """Create optimizer test parameters with automatic marks. + + Each test gets marked with its optimizer name for targeted test execution. + + Returns: + list[pytest.param]: List of pytest parameters with optimizer marks. + """ + # Call discover_optimizer_tests directly to avoid global variable issues + from .optimizer_tests import discover_optimizer_tests + + global _CACHED_TESTS + if _CACHED_TESTS is None: + _CACHED_TESTS = discover_optimizer_tests() + return [pytest.param(test, marks=pytest.mark.__getattr__(test.optimizer_name), id=test.name) for test in _CACHED_TESTS] + + +def create_marked_device_types(): + """Create device type parameters with marks. + + Returns: + list[pytest.param]: List of device parameters with device marks. + """ + return [ + pytest.param("cpu", marks=pytest.mark.cpu, id="cpu"), + pytest.param("gpu", marks=pytest.mark.gpu, id="gpu"), + ] + + +def create_marked_dtypes(): + """Create dtype parameters with marks. + + Only includes float32 and bfloat16 as specified in requirements. + + Returns: + list[pytest.param]: List of dtype parameters with dtype marks. + """ + return [ + pytest.param(torch.float32, marks=pytest.mark.float32, id="float32"), + pytest.param(torch.bfloat16, marks=pytest.mark.bfloat16, id="bfloat16"), + ] + + +def create_marked_backends(): + """Create backend parameters with marks. + + Only includes torch and triton backends as specified in requirements. + + Returns: + list[pytest.param]: List of backend parameters with backend marks. + """ + return [ + pytest.param("torch", marks=pytest.mark.torch, id="torch"), + pytest.param("triton", marks=pytest.mark.triton, id="triton"), + ] + + +def create_gpu_only_device_types(): + """Create device type parameters for GPU-only tests. + + Returns: + list[pytest.param]: List containing only GPU device parameter. + """ + return [pytest.param("gpu", marks=pytest.mark.gpu, id="gpu")] + + +def create_float32_only_dtypes(): + """Create dtype parameters for float32-only tests. + + Returns: + list[pytest.param]: List containing only float32 dtype parameter. + """ + return [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")] + + +def get_optimizer_marks(): + """Get all available optimizer marks. + + Returns: + list[str]: List of optimizer names that can be used as pytest marks. + """ + _initialize_test_registry() + return sorted({test.optimizer_name for test in ALL_OPTIMIZER_TESTS}) + + +def get_device_marks(): + """Get all available device marks. + + Returns: + list[str]: List of device names that can be used as pytest marks. + """ + return ["cpu", "gpu"] + + +def get_dtype_marks(): + """Get all available dtype marks. + + Returns: + list[str]: List of dtype names that can be used as pytest marks. + """ + return ["float32", "bfloat16"] + + +def get_backend_marks(): + """Get all available backend marks. + + Returns: + list[str]: List of backend names that can be used as pytest marks. + """ + return ["torch", "triton"] + + +def create_test_matrix(): + """Create the complete test matrix for all combinations. + + Returns: + dict: Dictionary containing all parameter combinations for testing. 
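+            The keys are optimizer_tests, device_types, dtypes, backends,
+            gpu_only_devices, and float32_only_dtypes.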
+ """ + return { + "optimizer_tests": create_marked_optimizer_tests(), + "device_types": create_marked_device_types(), + "dtypes": create_marked_dtypes(), + "backends": create_marked_backends(), + "gpu_only_devices": create_gpu_only_device_types(), + "float32_only_dtypes": create_float32_only_dtypes(), + } + + +def print_mark_summary(): + """Print a summary of available marks for debugging.""" + print("Available pytest marks:") + print(f" Optimizers: {', '.join(get_optimizer_marks())}") + print(f" Devices: {', '.join(get_device_marks())}") + print(f" Dtypes: {', '.join(get_dtype_marks())}") + print(f" Backends: {', '.join(get_backend_marks())}") + + print( + f"\nTotal test combinations: {len(create_marked_optimizer_tests())} optimizers × " + f"{len(get_device_marks())} devices × {len(get_dtype_marks())} dtypes × " + f"{len(get_backend_marks())} backends" + ) diff --git a/test_new/test_optimizers.py b/test_new/test_optimizers.py new file mode 100644 index 0000000..4be7bd4 --- /dev/null +++ b/test_new/test_optimizers.py @@ -0,0 +1,556 @@ +"""Main test class with comprehensive test methods for the unified optimizer test framework. + +This module provides the TestOptimizers class that implements all test types: +- test_optimizer_correctness: Validates optimizer correctness against reference implementations +- test_gradient_release: Tests gradient release functionality (GPU-only) +- test_optimizer_accumulation: Tests optimizer accumulation functionality (GPU-only) +""" + +import io + +import pytest +import torch +from optimi import prepare_for_gradient_release, remove_gradient_release +from optimi.utils import MIN_TORCH_2_6 +from torch import Tensor + +from .framework import OptimizerTest, ToleranceConfig, assert_most_approx_close +from .pytest_integration import ( + create_float32_only_dtypes, + create_gpu_only_device_types, + create_marked_backends, + create_marked_device_types, + create_marked_dtypes, + create_marked_optimizer_tests, +) + + +class MLP(torch.nn.Module): + """Simple MLP model for testing optimizer behavior.""" + + def __init__(self, input_size: int, hidden_size: int, device: torch.device, dtype: torch.dtype): + super().__init__() + self.norm = torch.nn.LayerNorm(input_size, device=device, dtype=dtype) + self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=False, device=device, dtype=dtype) + self.act = torch.nn.Mish() + self.fc2 = torch.nn.Linear(hidden_size, 1, bias=False, device=device, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + x = self.norm(x) + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +class TestOptimizers: + """Main test class for comprehensive optimizer testing.""" + + @pytest.mark.parametrize("optimizer_test", create_marked_optimizer_tests()) + @pytest.mark.parametrize("device_type", create_marked_device_types()) + @pytest.mark.parametrize("dtype", create_marked_dtypes()) + @pytest.mark.parametrize("backend", create_marked_backends()) + def test_optimizer_correctness( + self, + optimizer_test: OptimizerTest, + device_type: str, + dtype: torch.dtype, + backend: str, + gpu_device: str, + ) -> None: + """Test optimizer correctness against reference implementation. + + Validates that the optimi optimizer produces results consistent with + the reference PyTorch optimizer across different configurations. 
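+
+        Individual combinations can be selected with the generated pytest marks,
+        e.g. pytest -m "adam and cpu and float32 and torch".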
+ """ + # Skip test if conditions don't match + if self._should_skip("correctness", optimizer_test, device_type, dtype, backend): + pytest.skip(f"Skipping {optimizer_test.name} correctness test for {device_type}/{dtype}/{backend}") + + # Log a random seed for reproducibility while keeping randomness + seed = int(torch.randint(0, 2**31 - 1, (1,)).item()) + print(f"[seed] correctness: {optimizer_test.name} {device_type}/{dtype}/{backend} -> {seed}") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Determine actual device + device = torch.device(gpu_device if device_type == "gpu" else "cpu") + + # Run the correctness test + self._run_correctness_test(optimizer_test, device, dtype, backend) + + @pytest.mark.parametrize("optimizer_test", create_marked_optimizer_tests()) + @pytest.mark.parametrize("device_type", create_gpu_only_device_types()) + @pytest.mark.parametrize("dtype", create_float32_only_dtypes()) + @pytest.mark.parametrize("backend", create_marked_backends()) + def test_gradient_release( + self, + optimizer_test: OptimizerTest, + device_type: str, + dtype: torch.dtype, + backend: str, + gpu_device: str, + ) -> None: + """Test gradient release functionality (GPU only). + + Validates that gradient release produces consistent results with + standard optimizer behavior while freeing memory during backprop. + """ + # Skip test if conditions don't match + if self._should_skip("gradient_release", optimizer_test, device_type, dtype, backend): + pytest.skip(f"Skipping {optimizer_test.name} gradient_release test for {device_type}/{dtype}/{backend}") + + # Log a random seed for reproducibility while keeping randomness + seed = int(torch.randint(0, 2**31 - 1, (1,)).item()) + print(f"[seed] gradient_release: {optimizer_test.name} {device_type}/{dtype}/{backend} -> {seed}") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Determine actual device (always GPU for this test) + device = torch.device(gpu_device) + + # Run the gradient release test + self._run_gradient_release_test(optimizer_test, device, dtype, backend) + + @pytest.mark.parametrize("optimizer_test", create_marked_optimizer_tests()) + @pytest.mark.parametrize("device_type", create_gpu_only_device_types()) + @pytest.mark.parametrize("dtype", create_float32_only_dtypes()) + @pytest.mark.parametrize("backend", create_marked_backends()) + def test_optimizer_accumulation( + self, + optimizer_test: OptimizerTest, + device_type: str, + dtype: torch.dtype, + backend: str, + gpu_device: str, + ) -> None: + """Test optimizer accumulation functionality (GPU only). + + Validates that optimizer accumulation produces results consistent + with gradient accumulation while being more memory efficient. 
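+
+        The reference model is stepped with standard gradient accumulation, so the
+        final weights are only compared within loose tolerances (see
+        _run_accumulation_test).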
+ """ + # Skip test if conditions don't match + if self._should_skip("accumulation", optimizer_test, device_type, dtype, backend): + pytest.skip(f"Skipping {optimizer_test.name} accumulation test for {device_type}/{dtype}/{backend}") + + # Determine actual device (always GPU for this test) + device = torch.device(gpu_device) + + # Run the accumulation test + self._run_accumulation_test(optimizer_test, device, dtype, backend) + + def _prepare_kwargs(self, optimizer_test: OptimizerTest, backend: str) -> tuple[dict, dict]: + """Prepare reference and optimi kwargs including backend-specific flags.""" + reference_kwargs = optimizer_test.to_reference_kwargs() + optimi_kwargs = optimizer_test.to_optimi_kwargs() + if backend == "triton": + optimi_kwargs["triton"] = True + else: + optimi_kwargs["foreach"] = False + optimi_kwargs["triton"] = False + return reference_kwargs, optimi_kwargs + + def _should_skip( + self, + test_type: str, + optimizer_test: OptimizerTest, + device_type: str, + dtype: torch.dtype, + backend: str, + ) -> bool: + """Comprehensive test skipping logic with all conditions. + + Args: + test_type: Type of test ('correctness', 'gradient_release', 'accumulation') + optimizer_test: The optimizer test configuration + device_type: Device type ('cpu' or 'gpu') + dtype: Data type (torch.float32 or torch.bfloat16) + backend: Backend type ('torch' or 'triton') + + Returns: + True if test should be skipped, False otherwise + """ + # Check if test type is explicitly skipped + if optimizer_test.should_skip_test(test_type): + return True + + # Respect per-test dtype constraints if provided + if getattr(optimizer_test, "only_dtypes", None): + if dtype not in optimizer_test.only_dtypes: + return True + + # Skip triton tests on CPU + if backend == "triton" and device_type == "cpu": + return True + + # Skip triton tests if PyTorch version is too old + if backend == "triton" and not MIN_TORCH_2_6: + return True + + # Skip GPU tests if no GPU is available + if device_type == "gpu" and not (torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())): + return True + + # Gradient release and accumulation are GPU-only tests + if test_type in ["gradient_release", "accumulation"] and device_type == "cpu": + return True + + # Skip bfloat16 on CPU for most optimizers (matches original test behavior) + # Only anyadam tests bfloat16 on CPU for precision testing + if device_type == "cpu" and dtype == torch.bfloat16 and not optimizer_test.name.startswith("anyadam"): + return True + + return False + + def _run_correctness_test( + self, + optimizer_test: OptimizerTest, + device: torch.device, + dtype: torch.dtype, + backend: str, + ) -> None: + """Core correctness test implementation. + + Creates two identical models, runs them with optimi and reference optimizers, + and validates that they produce consistent results. 
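+
+        The optimi optimizer's state_dict is also periodically saved, reloaded into
+        a freshly constructed optimizer, and the weights re-verified afterwards.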
+ """ + # Get test configuration + iterations = optimizer_test.get_iterations_for_test("correctness") + tolerance = optimizer_test.get_tolerance_for_dtype(dtype) + + # Determine model dimensions and error handling based on device + if device.type == "cpu": + dim1, dim2 = 64, 128 + batch_size = 1 + max_error_count = 2 + else: + dim1, dim2 = 256, 512 + batch_size = 32 + max_error_count = 5 + + # Set max_error_rate for bfloat16 like original tests + max_error_rate = None + if dtype == torch.bfloat16: + max_error_rate = 0.01 # Allow 1% of values to be outside tolerance + + # Skip 1x1 tests + if dim1 == 1 and dim2 == 1: + pytest.skip("Skipping 1x1 optimizer test") + + # Create models + m1 = MLP(dim1, dim2, device=device, dtype=dtype) + m2 = MLP(dim1, dim2, device=device, dtype=dtype) + m2.load_state_dict(m1.state_dict()) + + # Convert model parameters to float for non-any_precision testing + if not optimizer_test.any_precision and dtype != torch.float32: + for p in m1.parameters(): + p.data = p.data.float() + + # Create optimizers + reference_class = optimizer_test.reference_class + reference_kwargs, optimi_kwargs = self._prepare_kwargs(optimizer_test, backend) + + reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) + optimi_optimizer = optimizer_test.optimi_class(m2.parameters(), **optimi_kwargs) + + # Training loop with state dict testing + buffer = io.BytesIO() + + for i in range(iterations): + # Generate training data + input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) + input2 = input1.detach().clone() + target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) + target2 = target1.detach().clone() + + # Convert inputs to float for non-any_precision testing + if not optimizer_test.any_precision and dtype != torch.float32: + input1 = input1.float() + target1 = target1.float() + + # Forward pass + output1 = m1(input1) + output2 = m2(input2) + + # Loss calculation + loss1 = torch.nn.functional.mse_loss(output1, target1) + loss2 = torch.nn.functional.mse_loss(output2, target2) + + # Backward pass + loss1.backward() + loss2.backward() + + # Optimizer step + reference_optimizer.step() + optimi_optimizer.step() + + # Zero gradients + reference_optimizer.zero_grad() + optimi_optimizer.zero_grad() + + # Compare model weights + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="fc1: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="fc2: ", + ) + + # Test state_dict saving and loading periodically + if i % max(1, iterations // 10) == 0 and i > 0: + # Save optimizer state + torch.save(optimi_optimizer.state_dict(), buffer) + buffer.seek(0) + # Load checkpoint + ckpt = torch.load(buffer, weights_only=True) + # Recreate optimizer and load its state + optimi_optimizer = optimizer_test.optimi_class(m2.parameters(), **optimi_kwargs) + optimi_optimizer.load_state_dict(ckpt) + # Clear buffer + buffer.seek(0) + buffer.truncate(0) + + # Verify models are still aligned after state_dict loading + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="fc1 after load: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + atol=tolerance.atol, + 
rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="fc2 after load: ", + ) + + def _run_gradient_release_test( + self, + optimizer_test: OptimizerTest, + device: torch.device, + dtype: torch.dtype, + backend: str, + ) -> None: + """Core gradient release test implementation. + + Compares gradient release behavior with torch PyTorch optimizer hooks + and regular optimizer behavior to ensure consistency. + """ + + def optimizer_hook(parameter) -> None: + torch_optimizers[parameter].step() + torch_optimizers[parameter].zero_grad() + + # Get test configuration + iterations = optimizer_test.get_iterations_for_test("gradient_release") + tolerance = optimizer_test.get_tolerance_for_dtype(dtype) + + # Set default tolerances for gradient release (slightly more lenient) + if optimizer_test.custom_tolerances is None: + if dtype == torch.float32: + tolerance = ToleranceConfig(rtol=1e-5, atol=2e-6) + elif dtype == torch.bfloat16: + tolerance = ToleranceConfig(rtol=1e-2, atol=2e-3) + + # Since Lion & Adan can have noisy updates, allow up to 12 errors + max_error_count = 12 + + # Model dimensions for gradient release tests + dim1, dim2 = 128, 256 + batch_size = 32 + + # Create three identical models + m1 = MLP(dim1, dim2, device=device, dtype=dtype) # Regular optimizer + m2 = MLP(dim1, dim2, device=device, dtype=dtype) # PyTorch hooks + m3 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi gradient release + m2.load_state_dict(m1.state_dict()) + m3.load_state_dict(m1.state_dict()) + + # Create optimizers + reference_class = optimizer_test.reference_class + reference_kwargs, optimi_kwargs = self._prepare_kwargs(optimizer_test, backend) + + # Regular optimizer + regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) + + # PyTorch Method: taken from https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + torch_optimizers = {p: reference_class([p], **reference_kwargs) for p in m2.parameters()} + + pytorch_hooks = [] + for p in m2.parameters(): + pytorch_hooks.append(p.register_post_accumulate_grad_hook(optimizer_hook)) + + # Optimi Method with gradient release + optimi_kwargs["gradient_release"] = True + optimi_optimizer = optimizer_test.optimi_class(m3.parameters(), **optimi_kwargs) + prepare_for_gradient_release(m3, optimi_optimizer) + + # Training loop + for i in range(iterations): + input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) + input2 = input1.clone() + input3 = input1.clone() + target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) + target2 = target1.clone() + target3 = target1.clone() + + output1 = m1(input1) + output2 = m2(input2) + output3 = m3(input3) + + loss1 = torch.nn.functional.mse_loss(output1, target1) + loss2 = torch.nn.functional.mse_loss(output2, target2) + loss3 = torch.nn.functional.mse_loss(output3, target3) + + loss1.backward() + loss2.backward() + loss3.backward() + + regular_optimizer.step() + regular_optimizer.zero_grad() + + # Simulate framework optimizer step (randomly enabled) + framework_opt_step = torch.rand(1).item() > 0.5 + if framework_opt_step: + optimi_optimizer.step() + optimi_optimizer.zero_grad() + + # Compare results + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + name="PyTorch-PyTorch: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + 
name="PyTorch-PyTorch: ", + ) + assert_most_approx_close( + m1.fc1.weight, + m3.fc1.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + name="PyTorch-Optimi: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m3.fc2.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + name="PyTorch-Optimi: ", + ) + + # Cleanup + for h in pytorch_hooks: + h.remove() + remove_gradient_release(m3) + + def _run_accumulation_test( + self, + optimizer_test: OptimizerTest, + device: torch.device, + dtype: torch.dtype, + backend: str, + ) -> None: + """Core accumulation test implementation. + + Tests optimizer accumulation functionality which approximates gradient + accumulation by accumulating directly into optimizer states. + """ + # Get test configuration + iterations = optimizer_test.get_iterations_for_test("accumulation") + + # Since optimizer accumulation approximates gradient accumulation, + # the tolerances are high despite the low number of iterations + max_error_rate = 0.035 + tolerance = ToleranceConfig(rtol=1e-2, atol=1e-2) + + # Model dimensions for accumulation tests + dim1, dim2 = 128, 256 + batch_size = 32 + + # Create two identical models + m1 = MLP(dim1, dim2, device=device, dtype=dtype) # Regular optimizer + m2 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi accumulation + m2.load_state_dict(m1.state_dict()) + + # Create optimizers + reference_class = optimizer_test.reference_class + reference_kwargs, optimi_kwargs = self._prepare_kwargs(optimizer_test, backend) + + # Regular optimizer + regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) + + # Optimi optimizer with gradient release for accumulation + optimi_kwargs["gradient_release"] = True + optimi_optimizer = optimizer_test.optimi_class(m2.parameters(), **optimi_kwargs) + prepare_for_gradient_release(m2, optimi_optimizer) + + gradient_accumulation_steps = 4 + + # Training loop + for i in range(iterations): + input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) + input2 = input1.clone() + target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) + target2 = target1.clone() + + # Set accumulation mode + optimi_optimizer.optimizer_accumulation = (i + 1) % gradient_accumulation_steps != 0 + + output1 = m1(input1) + output2 = m2(input2) + + loss1 = torch.nn.functional.mse_loss(output1, target1) + loss2 = torch.nn.functional.mse_loss(output2, target2) + + loss1.backward() + loss2.backward() + + # Only step regular optimizer when not accumulating + if not optimi_optimizer.optimizer_accumulation: + regular_optimizer.step() + regular_optimizer.zero_grad() + + # Simulate framework optimizer step (randomly enabled) + framework_opt_step = torch.rand(1).item() > 0.5 + if framework_opt_step: + optimi_optimizer.step() + optimi_optimizer.zero_grad() + + # Unlike other tests, compare that the weights are in the same approximate range at the end of training + assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) + assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) + + # Cleanup + remove_gradient_release(m2) From e0244bd4695175d1852a377b8ee1a70eeb60c3cb Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 12 Aug 2025 12:58:41 -0500 Subject: [PATCH 02/16] attempt to fix radam --- test_new/opt_radam.py | 19 ++++++++++++++----- test_new/test_optimizers.py | 8 ++++---- 2 files changed, 
18 insertions(+), 9 deletions(-) diff --git a/test_new/opt_radam.py b/test_new/opt_radam.py index 0154cc8..4594ce5 100644 --- a/test_new/opt_radam.py +++ b/test_new/opt_radam.py @@ -1,11 +1,12 @@ """RAdam optimizer test definitions.""" -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any import optimi import torch -from .framework import BaseParams, OptimizerTest +from .framework import BaseParams, OptimizerTest, ToleranceConfig @dataclass @@ -13,13 +14,21 @@ class RAdamParams(BaseParams): """Type-safe RAdam optimizer parameters.""" betas: tuple[float, float] = (0.9, 0.999) - eps: float = 1e-8 + eps: float = 1e-6 + decoupled_weight_decay: bool = field(default=False) + + def __post_init__(self): + if self.decouple_wd: + self.decoupled_weight_decay = True + elif self.decouple_lr: + self.decoupled_weight_decay = True BASE_TEST = OptimizerTest( name="radam_base", optimi_class=optimi.RAdam, - optimi_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + optimi_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), reference_class=torch.optim.RAdam, - reference_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + reference_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), + custom_tolerances={torch.float32: ToleranceConfig(max_error_rate=0.001)}, ) diff --git a/test_new/test_optimizers.py b/test_new/test_optimizers.py index 4be7bd4..a35dccb 100644 --- a/test_new/test_optimizers.py +++ b/test_new/test_optimizers.py @@ -298,7 +298,7 @@ def _run_correctness_test( atol=tolerance.atol, rtol=tolerance.rtol, max_error_count=max_error_count, - max_error_rate=max_error_rate, + max_error_rate=tolerance.max_error_rate, name="fc1: ", ) assert_most_approx_close( @@ -307,7 +307,7 @@ def _run_correctness_test( atol=tolerance.atol, rtol=tolerance.rtol, max_error_count=max_error_count, - max_error_rate=max_error_rate, + max_error_rate=tolerance.max_error_rate, name="fc2: ", ) @@ -332,7 +332,7 @@ def _run_correctness_test( atol=tolerance.atol, rtol=tolerance.rtol, max_error_count=max_error_count, - max_error_rate=max_error_rate, + max_error_rate=tolerance.max_error_rate, name="fc1 after load: ", ) assert_most_approx_close( @@ -341,7 +341,7 @@ def _run_correctness_test( atol=tolerance.atol, rtol=tolerance.rtol, max_error_count=max_error_count, - max_error_rate=max_error_rate, + max_error_rate=tolerance.max_error_rate, name="fc2 after load: ", ) From 7672048912dc9f237faa351a988fd277f7b1d70f Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Wed, 13 Aug 2025 11:56:44 -0500 Subject: [PATCH 03/16] small fixes --- test_new/opt_adam.py | 2 +- test_new/opt_adamw.py | 4 ++-- test_new/opt_radam.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/test_new/opt_adam.py b/test_new/opt_adam.py index a1e0fb3..c9674ea 100644 --- a/test_new/opt_adam.py +++ b/test_new/opt_adam.py @@ -19,7 +19,7 @@ class AdamParams(BaseParams): BASE_TEST = OptimizerTest( name="adam_base", optimi_class=optimi.Adam, - optimi_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + optimi_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), reference_class=torch.optim.Adam, test_decoupled_wd=False, ) diff --git a/test_new/opt_adamw.py b/test_new/opt_adamw.py index 6845617..c7ce386 100644 --- a/test_new/opt_adamw.py +++ b/test_new/opt_adamw.py @@ -14,13 +14,13 @@ class AdamWParams(BaseParams): """Type-safe AdamW optimizer parameters.""" betas: tuple[float, float] = (0.9, 0.99) - eps: 
float = 1e-8 + eps: float = 1e-6 BASE_TEST = OptimizerTest( name="adamw_base", optimi_class=optimi.AdamW, - optimi_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + optimi_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), reference_class=torch.optim.AdamW, fully_decoupled_reference=reference.DecoupledAdamW, ) diff --git a/test_new/opt_radam.py b/test_new/opt_radam.py index 4594ce5..03f187d 100644 --- a/test_new/opt_radam.py +++ b/test_new/opt_radam.py @@ -1,7 +1,7 @@ """RAdam optimizer test definitions.""" +import inspect from dataclasses import dataclass, field -from typing import Any import optimi import torch @@ -13,8 +13,8 @@ class RAdamParams(BaseParams): """Type-safe RAdam optimizer parameters.""" - betas: tuple[float, float] = (0.9, 0.999) - eps: float = 1e-6 + betas: tuple[float, float] = (0.9, 0.99) + eps: float = 1e-8 decoupled_weight_decay: bool = field(default=False) def __post_init__(self): @@ -31,4 +31,5 @@ def __post_init__(self): reference_class=torch.optim.RAdam, reference_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), custom_tolerances={torch.float32: ToleranceConfig(max_error_rate=0.001)}, + test_decoupled_wd="decoupled_weight_decay" in inspect.signature(torch.optim.RAdam.__init__).parameters, ) From f603ad59984f19ef607f1cbd7e3757e254f7dbe4 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 26 Aug 2025 08:46:19 -0500 Subject: [PATCH 04/16] improve new tests --- test_new/__init__.py | 71 ---- test_new/cases.py | 199 ++++++++++ test_new/conftest.py | 136 ++----- test_new/framework.py | 166 --------- test_new/opt_adam.py | 15 +- test_new/opt_adamw.py | 14 +- test_new/opt_adan.py | 34 +- test_new/opt_anyadam.py | 82 +--- test_new/opt_lion.py | 24 +- test_new/opt_radam.py | 16 +- test_new/opt_ranger.py | 15 +- test_new/opt_sgd.py | 31 +- test_new/opt_stableadamw.py | 10 +- test_new/optimizer_tests.py | 282 -------------- test_new/pytest_integration.py | 155 -------- test_new/runners.py | 360 ++++++++++++++++++ test_new/test_optimizers.py | 556 ---------------------------- test_new/test_optimizers_unified.py | 114 ++++++ 18 files changed, 783 insertions(+), 1497 deletions(-) create mode 100644 test_new/cases.py delete mode 100644 test_new/framework.py delete mode 100644 test_new/optimizer_tests.py delete mode 100644 test_new/pytest_integration.py create mode 100644 test_new/runners.py delete mode 100644 test_new/test_optimizers.py create mode 100644 test_new/test_optimizers_unified.py diff --git a/test_new/__init__.py b/test_new/__init__.py index b4e3a8f..8b13789 100644 --- a/test_new/__init__.py +++ b/test_new/__init__.py @@ -1,72 +1 @@ -""" -Unified optimizer test framework. -This module provides a simplified, dataclass-based approach to optimizer testing -that replaces the complex OptimizerSpec/OptimizerVariant architecture. 
-""" - -__version__ = "1.0.0" - -# Core framework components -from .framework import BaseParams, OptimizerTest, ToleranceConfig - -# Test discovery and registry -from .optimizer_tests import ( - ALL_OPTIMIZER_TESTS, - auto_generate_variants, - discover_optimizer_tests, - get_all_optimizer_names, - get_all_variant_names, - get_test_by_name, - get_test_count, - get_tests_by_optimizer, - get_tests_by_variant, - print_test_summary, -) - -# Pytest integration -from .pytest_integration import ( - create_marked_backends, - create_marked_device_types, - create_marked_dtypes, - create_marked_optimizer_tests, - create_float32_only_dtypes, - create_gpu_only_device_types, - create_test_matrix, - get_backend_marks, - get_device_marks, - get_dtype_marks, - get_optimizer_marks, - print_mark_summary, -) - -__all__ = [ - # Core framework - "BaseParams", - "OptimizerTest", - "ToleranceConfig", - # Test discovery - "ALL_OPTIMIZER_TESTS", - "auto_generate_variants", - "discover_optimizer_tests", - "get_all_optimizer_names", - "get_all_variant_names", - "get_test_by_name", - "get_test_count", - "get_tests_by_optimizer", - "get_tests_by_variant", - "print_test_summary", - # Pytest integration - "create_marked_backends", - "create_marked_device_types", - "create_marked_dtypes", - "create_marked_optimizer_tests", - "create_float32_only_dtypes", - "create_gpu_only_device_types", - "create_test_matrix", - "get_backend_marks", - "get_device_marks", - "get_dtype_marks", - "get_optimizer_marks", - "print_mark_summary", -] diff --git a/test_new/cases.py b/test_new/cases.py new file mode 100644 index 0000000..4ac6a19 --- /dev/null +++ b/test_new/cases.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import importlib +from dataclasses import dataclass, field, replace +from pathlib import Path +from typing import Any + +import torch +from optimi.optimizer import OptimiOptimizer +from torch.optim import Optimizer + + +@dataclass +class Tolerance: + atol: float = 1e-6 + rtol: float = 1e-5 + max_error_rate: float = 5e-4 + equal_nan: bool = False + + +@dataclass +class BaseParams: + lr: float = 1e-3 + weight_decay: float = 0.0 + decouple_wd: bool = False + decouple_lr: bool = False + triton: bool = False + + def with_(self, **overrides: Any) -> "BaseParams": + return replace(self, **overrides) + + def _kwargs_for(self, cls: type | None) -> dict[str, Any]: + import inspect + from dataclasses import asdict + + if cls is None: + return {} + sig = inspect.signature(cls.__init__) + ok = set(sig.parameters) - {"self"} + return {k: v for k, v in asdict(self).items() if k in ok} + + def to_optimi_kwargs(self, cls: type[OptimiOptimizer]) -> dict[str, Any]: + return self._kwargs_for(cls) + + def to_reference_kwargs(self, cls: type[Optimizer]) -> dict[str, Any]: + return self._kwargs_for(cls) + + +@dataclass +class Case: + # Identification + name: str # e.g. 
"adam_base" + + # Classes + params + optimi_class: type[OptimiOptimizer] + optimi_params: BaseParams + reference_class: type[Optimizer] + reference_params: BaseParams | None = None + + # Optional fully decoupled reference for decoupled-lr variant + fully_decoupled_reference: type[Optimizer] | None = None + + # Behavior / constraints + test_decoupled_wd: bool = True + skip_tests: list[str] = field(default_factory=list) + any_precision: bool = False + custom_iterations: dict[str, int] | None = None + custom_tolerances: dict[torch.dtype, Tolerance] | None = None + only_dtypes: list[torch.dtype] | None = None + + def __post_init__(self): + if self.reference_params is None: + self.reference_params = self.optimi_params + if self.custom_tolerances is None: + self.custom_tolerances = {} + # reasonable defaults; override per-case as needed + self.custom_tolerances.setdefault(torch.float32, Tolerance()) + self.custom_tolerances.setdefault(torch.bfloat16, Tolerance(atol=1e-3, rtol=1e-2, max_error_rate=0.01)) + self.custom_tolerances.setdefault(torch.float16, Tolerance(atol=1e-4, rtol=1e-3, max_error_rate=0.01)) + + @property + def optimizer_name(self) -> str: + return self.name.split("_", 1)[0] + + @property + def variant_name(self) -> str: + return self.name.split("_", 1)[1] if "_" in self.name else "base" + + def to_optimi_kwargs(self, backend: str | None = None) -> dict[str, Any]: + # Both new BaseParams and legacy test_new.framework.BaseParams expose this + kw = self.optimi_params.to_optimi_kwargs(self.optimi_class) + + # Centralize backend controls so runners don't mutate kwargs later + if backend is not None: + if backend == "triton": + kw["triton"] = True + kw["foreach"] = False + elif backend == "torch": + kw["triton"] = False + kw["foreach"] = False + elif backend == "foreach": + kw["triton"] = False + kw["foreach"] = True + else: + raise ValueError(f"Unknown backend: {backend}") + return kw + + def to_reference_kwargs(self) -> dict[str, Any]: + assert self.reference_params is not None + return self.reference_params.to_reference_kwargs(self.reference_class) + + def supports_l2_weight_decay(self) -> bool: + import inspect + + return "decouple_wd" in inspect.signature(self.optimi_class.__init__).parameters + + +def default_variants(base: Case) -> list[Case]: + """Generate base + L2 + decoupled variants with minimal boilerplate.""" + out: list[Case] = [] + + base0 = Case( + name=f"{base.optimizer_name}_base", + optimi_class=base.optimi_class, + optimi_params=base.optimi_params.with_(weight_decay=0.0, decouple_wd=False, decouple_lr=False), + reference_class=base.reference_class, + reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.0, decouple_wd=False, decouple_lr=False), + test_decoupled_wd=base.test_decoupled_wd, + skip_tests=list(base.skip_tests), + any_precision=base.any_precision, + custom_iterations=base.custom_iterations, + custom_tolerances=base.custom_tolerances, + only_dtypes=base.only_dtypes, + fully_decoupled_reference=base.fully_decoupled_reference, + ) + out.append(base0) + + import inspect + + # L2 (coupled) if optimizer supports decouple_wd arg + if "decouple_wd" in inspect.signature(base.optimi_class.__init__).parameters: + out.append( + replace( + base0, + name=f"{base.optimizer_name}_l2_wd", + optimi_params=base.optimi_params.with_(weight_decay=0.01, decouple_wd=False), + reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.01, decouple_wd=False), + ) + ) + + # Decoupled weight decay + if base.test_decoupled_wd: + 
out.append( + replace( + base0, + name=f"{base.optimizer_name}_decoupled_wd", + optimi_params=base.optimi_params.with_(weight_decay=0.01, decouple_wd=True), + reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.01, decouple_wd=True), + ) + ) + + # Decoupled LR (optionally swap reference class) + ref_cls = base.fully_decoupled_reference or base.reference_class + out.append( + replace( + base0, + name=f"{base.optimizer_name}_decoupled_lr", + optimi_params=base.optimi_params.with_(weight_decay=1e-5, decouple_lr=True), + reference_class=ref_cls, + reference_params=(base.reference_params or base.optimi_params).with_( + weight_decay=1e-5 if base.fully_decoupled_reference else 0.01, + decouple_lr=True, + ), + ) + ) + return out + + +def discover_cases(root: Path | None = None) -> list[Case]: + """ + Discover `opt_*.py` modules in this package. Accept exactly: + - TESTS: list[Case] + - BASE: Case -> expanded via default_variants(BASE) + """ + if root is None: + root = Path(__file__).parent + cases: list[Case] = [] + for f in root.glob("opt_*.py"): + mod = importlib.import_module(f".{f.stem}", package=__package__) + if hasattr(mod, "TESTS"): + cases.extend(getattr(mod, "TESTS")) + elif hasattr(mod, "BASE"): + base = getattr(mod, "BASE") + cases.extend(default_variants(base)) + return cases + + +def optimizer_names() -> list[str]: + return sorted({c.optimizer_name for c in discover_cases()}) diff --git a/test_new/conftest.py b/test_new/conftest.py index 5fb7f39..323fdec 100644 --- a/test_new/conftest.py +++ b/test_new/conftest.py @@ -1,14 +1,13 @@ """Pytest configuration and fixtures for the unified optimizer test framework. -This module provides pytest configuration, custom mark registration, and fixtures -for running optimizer tests across different devices, dtypes, and backends. +This module provides pytest configuration, custom mark registration, and the +`gpu_device` fixture used by tests. """ import pytest import torch -from packaging import version -from .optimizer_tests import get_all_optimizer_names +from .cases import optimizer_names def pytest_configure(config): @@ -26,110 +25,53 @@ def pytest_configure(config): config.addinivalue_line("markers", "torch: mark test to run with torch backend") config.addinivalue_line("markers", "triton: mark test to run with triton backend") - optimizer_names = get_all_optimizer_names() - for optimizer_name in optimizer_names: - config.addinivalue_line("markers", f"{optimizer_name}: mark test for {optimizer_name} optimizer") + # Per-optimizer marks (e.g., -m adam, -m sgd) + for opt_name in optimizer_names(): + config.addinivalue_line("markers", f"{opt_name}: mark test for {opt_name} optimizer") -# Check for minimum PyTorch version for Triton support -MIN_TORCH_2_6 = version.parse("2.6.0") -CURRENT_TORCH_VERSION = version.parse(torch.__version__.split("+")[0]) # Remove any +cu118 suffix -HAS_TRITON_SUPPORT = CURRENT_TORCH_VERSION >= MIN_TORCH_2_6 +def pytest_addoption(parser): + """Add command-line option to specify a single GPU""" + parser.addoption("--gpu-id", action="store", type=int, default=None, help="Specify a single GPU to use (e.g. --gpu-id=0)") -@pytest.fixture(scope="session") -def gpu_device(): - """Provide GPU device for testing if available. +@pytest.fixture() +def gpu_device(worker_id, request): + """Map xdist workers to available GPU devices in a round-robin fashion, + supporting CUDA (NVIDIA/ROCm) and XPU (Intel) backends. 
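+    Falls back to the bare backend device (no index) when running as the xdist
+    master worker or when no devices are enumerated.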
+ Use a single specified GPU if --gpu-id is provided""" - Returns: - torch.device: GPU device (cuda, xpu, or mps) if available, otherwise None. - """ + # Check if specific GPU was requested + specific_gpu = request.config.getoption("--gpu-id") + + # Determine available GPU backend and device count if torch.cuda.is_available(): - return torch.device("cuda") + backend = "cuda" + device_count = torch.cuda.device_count() elif hasattr(torch, "xpu") and torch.xpu.is_available(): - return torch.device("xpu") + backend = "xpu" + device_count = torch.xpu.device_count() elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - return torch.device("mps") + backend = "mps" + device_count = 0 else: - return None - - -@pytest.fixture(scope="session") -def has_gpu(gpu_device): - """Check if GPU is available for testing. - - Returns: - bool: True if GPU device is available, False otherwise. - """ - return gpu_device is not None - - -@pytest.fixture(scope="session") -def has_triton(): - """Check if Triton backend is available. - - Returns: - bool: True if Triton is supported (PyTorch >= 2.6), False otherwise. - """ - return HAS_TRITON_SUPPORT - - -@pytest.fixture -def tolerance_config(): - """Provide default tolerance configuration for numerical comparisons. - - Returns: - dict: Default tolerance settings for different dtypes. - """ - from .framework import ToleranceConfig - - return { - torch.float32: ToleranceConfig(rtol=1e-5, atol=1e-8), - torch.bfloat16: ToleranceConfig(rtol=1e-3, atol=1e-5), # More relaxed for bfloat16 - } - - -@pytest.fixture -def cpu_device(): - """Provide CPU device for testing. - - Returns: - torch.device: CPU device. - """ - return torch.device("cpu") - - -def pytest_collection_modifyitems(config, items): - """Modify test collection to add automatic skipping for unavailable resources.""" - - # Skip GPU tests if no GPU is available - if not torch.cuda.is_available() and not (hasattr(torch, "xpu") and torch.xpu.is_available()): - skip_gpu = pytest.mark.skip(reason="GPU not available") - for item in items: - if "gpu" in item.keywords: - item.add_marker(skip_gpu) + # Fallback to cuda for compatibility + backend = "cuda" + device_count = 0 - # Skip Triton tests if not supported - if not HAS_TRITON_SUPPORT: - skip_triton = pytest.mark.skip(reason=f"Triton requires PyTorch >= {MIN_TORCH_2_6}, got {CURRENT_TORCH_VERSION}") - for item in items: - if "triton" in item.keywords: - item.add_marker(skip_triton) + if specific_gpu is not None: + return torch.device(f"{backend}:{specific_gpu}") + if worker_id == "master": + return torch.device(backend) -def pytest_runtest_setup(item): - """Setup hook to perform additional test skipping based on marks.""" + # If no devices available, return default backend + if device_count == 0: + return torch.device(backend) - # Skip GPU tests on CPU-only systems - if "gpu" in item.keywords: - gpu_available = ( - torch.cuda.is_available() - or (hasattr(torch, "xpu") and torch.xpu.is_available()) - or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) - ) - if not gpu_available: - pytest.skip("GPU not available") + # Extract worker number from worker_id (e.g., 'gw6' -> 6) + worker_num = int(worker_id.replace("gw", "")) - # Skip Triton tests if not supported - if "triton" in item.keywords and not HAS_TRITON_SUPPORT: - pytest.skip(f"Triton requires PyTorch >= {MIN_TORCH_2_6}, got {CURRENT_TORCH_VERSION}") + # Map worker to GPU index using modulo to round-robin + gpu_idx = (worker_num - 1) % device_count + return 
torch.device(f"{backend}:{gpu_idx}") diff --git a/test_new/framework.py b/test_new/framework.py deleted file mode 100644 index bb71b62..0000000 --- a/test_new/framework.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Core framework components for the unified optimizer test system. - -This module provides the foundational dataclasses and utilities for defining -and executing optimizer tests in a type-safe, self-contained manner. -""" - -import inspect -from copy import deepcopy -from dataclasses import asdict, dataclass, field -from typing import Any - -import torch -from optimi.optimizer import OptimiOptimizer -from torch.optim.optimizer import Optimizer - - -@dataclass -class ToleranceConfig: - """Tolerance configuration for numerical comparisons.""" - - atol: float = 1e-6 - rtol: float = 1e-5 - max_error_rate: float = 0.0005 - equal_nan: bool = False - - -@dataclass -class BaseParams: - """Base class for all optimizer parameters with common fields.""" - - lr: float = 1e-3 - weight_decay: float = 0.0 - decouple_wd: bool = False - decouple_lr: bool = False - triton: bool = False - - def _filter_kwargs_for_class(self, optimizer_class: type) -> dict[str, Any]: - """Filter parameters based on optimizer signature inspection.""" - if optimizer_class is None: - return {} - - # Get the optimizer's __init__ signature - sig = inspect.signature(optimizer_class.__init__) - valid_params = set(sig.parameters.keys()) - {"self"} - - # Filter our parameters to only include those accepted by the optimizer - return {k: v for k, v in asdict(self).items() if k in valid_params} - - def to_optimi_kwargs(self, optimi_class: type) -> dict[str, Any]: - """Convert to kwargs for optimi optimizer.""" - return self._filter_kwargs_for_class(optimi_class) - - def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: - """Convert to kwargs for reference optimizer.""" - return self._filter_kwargs_for_class(reference_class) - - -@dataclass -class OptimizerTest: - """Complete self-contained optimizer test case.""" - - # Test identification - name: str # "adam_base", "sgd_momentum", etc. 
- - # Optimizer classes and parameters - optimi_class: type[OptimiOptimizer] - optimi_params: BaseParams - reference_class: type[Optimizer] - reference_params: BaseParams | None = None - - # Optional fully decoupled reference - fully_decoupled_reference: Optimizer | None = None - - # Test behavior overrides (optional) - test_decoupled_wd: bool = True - skip_tests: list[str] = field(default_factory=list) - any_precision: bool = False - custom_iterations: dict[str, int] | None = None - custom_tolerances: dict[torch.dtype, ToleranceConfig] | None = None - # Optional constraints - only_dtypes: list[torch.dtype] | None = None - - def __post_init__(self): - """Post-initialization checks and adjustments.""" - if self.reference_params is None: - self.reference_params = deepcopy(self.optimi_params) - - if self.custom_tolerances is None: - self.custom_tolerances = {} - if self.custom_tolerances.get(torch.float32, None) is None: - self.custom_tolerances[torch.float32] = ToleranceConfig() - if self.custom_tolerances.get(torch.bfloat16, None) is None: - self.custom_tolerances[torch.bfloat16] = ToleranceConfig(atol=1e-3, rtol=1e-2, max_error_rate=0.01) - if self.custom_tolerances.get(torch.float16, None) is None: - self.custom_tolerances[torch.float16] = ToleranceConfig(atol=1e-4, rtol=1e-3, max_error_rate=0.01) - - @property - def optimizer_name(self) -> str: - """Extract optimizer name from test name (e.g., 'adam' from 'adam_base').""" - return self.name.split("_")[0] if "_" in self.name else self.name - - @property - def variant_name(self) -> str: - """Extract variant name from test name (e.g., 'base' from 'adam_base').""" - parts = self.name.split("_", 1) - return parts[1] if len(parts) > 1 else "base" - - def to_optimi_kwargs(self) -> dict[str, Any]: - """Get kwargs for optimi optimizer.""" - return self.optimi_params.to_optimi_kwargs(self.optimi_class) - - def to_reference_kwargs(self) -> dict[str, Any]: - """Get kwargs for reference optimizer.""" - return self.reference_params.to_reference_kwargs(self.reference_class) - - def should_skip_test(self, test_type: str) -> bool: - """Check if a specific test type should be skipped.""" - return test_type in self.skip_tests - - def get_tolerance(self, dtype: torch.dtype) -> ToleranceConfig: - """Get tolerance configuration for specific dtype.""" - return self.custom_tolerances[dtype] - - # Backwards-compatible alias to support existing call sites - def get_tolerance_for_dtype(self, dtype: torch.dtype) -> ToleranceConfig: - """Backward-compatible alias for get_tolerance.""" - return self.get_tolerance(dtype) - - def get_iterations_for_test(self, test_type: str) -> int: - """Get number of iterations for specific test type.""" - if self.custom_iterations and test_type in self.custom_iterations: - return self.custom_iterations[test_type] - - # Default iterations based on test type - defaults = {"correctness": 10, "gradient_release": 5, "accumulation": 5} - return defaults.get(test_type, 10) - - def supports_l2_weight_decay(self) -> bool: - """Check if optimizer supports L2 weight decay.""" - # all optimi optimizers which support l2 weight decay have a decouple_wd parameter - return "decouple_wd" in inspect.signature(self.optimi_class.__init__).parameters - - -def assert_most_approx_close( - a: torch.Tensor, - b: torch.Tensor, - rtol: float = 1e-3, - atol: float = 1e-3, - max_error_count: int = 0, - max_error_rate: float | None = None, - name: str = "", -) -> None: - """Assert that most values in two tensors are approximately close. 
- - Allows for a small number of errors based on max_error_count and max_error_rate. - """ - idx = torch.isclose(a.float(), b.float(), rtol=rtol, atol=atol) - error_count = (idx == 0).sum().item() - - if max_error_rate is not None: - if error_count > (a.numel()) * max_error_rate and error_count > max_error_count: - print(f"{name}Too many values not close: assert {error_count} < {(a.numel()) * max_error_rate}") - torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) - elif error_count > max_error_count: - print(f"{name}Too many values not close: assert {error_count} < {max_error_count}") - torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) diff --git a/test_new/opt_adam.py b/test_new/opt_adam.py index c9674ea..4260617 100644 --- a/test_new/opt_adam.py +++ b/test_new/opt_adam.py @@ -1,25 +1,26 @@ -"""Adam optimizer test definitions.""" +"""Adam optimizer definitions using the new Case/variants flow.""" from dataclasses import dataclass import optimi import torch -from .framework import BaseParams, OptimizerTest +from .cases import BaseParams, Case @dataclass class AdamParams(BaseParams): - """Type-safe Adam optimizer parameters.""" - betas: tuple[float, float] = (0.9, 0.99) eps: float = 1e-6 -BASE_TEST = OptimizerTest( - name="adam_base", +# Provide BASE so the framework generates base/l2/decoupled variants as applicable. +# For Adam, we disable decoupled WD/LR generation to match prior behavior. +BASE = Case( + name="adam", optimi_class=optimi.Adam, - optimi_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), + optimi_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.0), reference_class=torch.optim.Adam, + reference_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.0), test_decoupled_wd=False, ) diff --git a/test_new/opt_adamw.py b/test_new/opt_adamw.py index c7ce386..c4d9525 100644 --- a/test_new/opt_adamw.py +++ b/test_new/opt_adamw.py @@ -1,4 +1,4 @@ -"""AdamW optimizer test definitions.""" +"""AdamW optimizer definitions using the new Case/variants flow.""" from dataclasses import dataclass @@ -6,21 +6,21 @@ import torch from tests import reference -from .framework import BaseParams, OptimizerTest +from .cases import BaseParams, Case @dataclass class AdamWParams(BaseParams): - """Type-safe AdamW optimizer parameters.""" - betas: tuple[float, float] = (0.9, 0.99) eps: float = 1e-6 -BASE_TEST = OptimizerTest( - name="adamw_base", +# Provide BASE with fully_decoupled_reference so decoupled_lr uses DecoupledAdamW +BASE = Case( + name="adamw", optimi_class=optimi.AdamW, - optimi_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), + optimi_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), reference_class=torch.optim.AdamW, + reference_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), fully_decoupled_reference=reference.DecoupledAdamW, ) diff --git a/test_new/opt_adan.py b/test_new/opt_adan.py index beaba77..302e9e2 100644 --- a/test_new/opt_adan.py +++ b/test_new/opt_adan.py @@ -1,4 +1,4 @@ -"""Adan optimizer test definitions.""" +"""Adan optimizer tests using the new Case format (manual list).""" from dataclasses import dataclass from typing import Any @@ -6,69 +6,51 @@ import optimi from tests import reference -from .framework import BaseParams, OptimizerTest +from .cases import BaseParams, Case @dataclass class AdanParams(BaseParams): - """Type-safe Adan optimizer parameters.""" - betas: tuple[float, float, float] = (0.98, 0.92, 0.99) eps: 
float = 1e-8 weight_decouple: bool = False # For adam_wd variant (maps to no_prox in reference) adam_wd: bool = False # For optimi optimizer def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: - """Adan needs special parameter conversion for no_prox.""" kwargs = super().to_reference_kwargs(reference_class) - - # Convert weight_decouple to no_prox for reference optimizer if "weight_decouple" in kwargs: kwargs["no_prox"] = kwargs.pop("weight_decouple") - - # Remove adam_wd as it's not used by reference kwargs.pop("adam_wd", None) - return kwargs -# Define all Adan test variants explicitly to match original tests -ALL_TESTS = [ - OptimizerTest( +TESTS = [ + Case( name="adan_base", optimi_class=optimi.Adan, optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=0), reference_class=reference.Adan, reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6), - custom_iterations={"correctness": 20}, # Adan bfloat16 updates are noisier, use fewer iterations for GPU ), - OptimizerTest( + Case( name="adan_weight_decay", optimi_class=optimi.Adan, optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), reference_class=reference.Adan, reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), - custom_iterations={"correctness": 20}, ), - OptimizerTest( + Case( name="adan_adam_wd", optimi_class=optimi.Adan, optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, adam_wd=True), reference_class=reference.Adan, - reference_params=AdanParams( - lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, weight_decouple=True - ), # no_prox=True in reference - custom_iterations={"correctness": 20}, + reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, weight_decouple=True), ), - OptimizerTest( + Case( name="adan_decoupled_lr", optimi_class=optimi.Adan, optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-5, decouple_lr=True), reference_class=reference.Adan, reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), - custom_iterations={"correctness": 20}, ), ] - -# Set BASE_TEST for auto-generation compatibility -BASE_TEST = ALL_TESTS[0] diff --git a/test_new/opt_anyadam.py b/test_new/opt_anyadam.py index 721f77a..df0bb6f 100644 --- a/test_new/opt_anyadam.py +++ b/test_new/opt_anyadam.py @@ -1,4 +1,4 @@ -"""AnyAdam optimizer test definitions for Kahan summation precision tests.""" +"""AnyAdam optimizer tests using the new Case format (manual list).""" from dataclasses import dataclass @@ -6,104 +6,56 @@ import torch from tests.reference import AnyPrecisionAdamW -from .framework import BaseParams, OptimizerTest, ToleranceConfig +from .cases import BaseParams, Case, Tolerance @dataclass class AnyAdamParams(BaseParams): - """Type-safe AnyAdam optimizer parameters with Kahan summation support.""" - betas: tuple[float, float] = (0.9, 0.999) eps: float = 1e-8 kahan_sum: bool = False use_kahan_summation: bool = False def to_reference_kwargs(self, reference_class: type) -> dict: - """Convert parameters for AnyPrecisionAdamW reference.""" kwargs = super().to_reference_kwargs(reference_class) - - # AnyPrecisionAdamW uses use_kahan_summation instead of kahan_sum if "kahan_sum" in kwargs: kwargs["use_kahan_summation"] = kwargs.pop("kahan_sum") - - # Set default precision dtypes for AnyPrecisionAdamW if reference_class.__name__ == "AnyPrecisionAdamW": 
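+            # Keep the reference optimizer's state in bfloat16 so it mirrors the
+            # pure-bfloat16 Kahan-summation setups these cases exercise.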
kwargs.setdefault("momentum_dtype", torch.bfloat16) kwargs.setdefault("variance_dtype", torch.bfloat16) kwargs.setdefault("compensation_buffer_dtype", torch.bfloat16) - return kwargs -ALL_TESTS = [ - OptimizerTest( +TESTS = [ + Case( name="anyadam_kahan", optimi_class=optimi.Adam, - optimi_params=AnyAdamParams( - lr=1e-3, - betas=(0.9, 0.99), - eps=1e-6, - weight_decay=0, - kahan_sum=True, - ), + optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, kahan_sum=True), reference_class=AnyPrecisionAdamW, - reference_params=AnyAdamParams( - lr=1e-3, - betas=(0.9, 0.99), - eps=1e-6, - weight_decay=0, - use_kahan_summation=True, - ), + reference_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, - custom_tolerances={torch.bfloat16: ToleranceConfig(rtol=2e-2, atol=2e-3, equal_nan=False)}, + custom_tolerances={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01, equal_nan=False)}, ), - OptimizerTest( + Case( name="anyadam_kahan_wd", - optimi_class=optimi.Adam, - optimi_params=AnyAdamParams( - lr=1e-3, - betas=(0.9, 0.99), - eps=1e-6, - weight_decay=0.01, - kahan_sum=True, - ), + optimi_class=optimi.AdamW, + optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.01, kahan_sum=True), reference_class=AnyPrecisionAdamW, - reference_params=AnyAdamParams( - lr=1e-3, - betas=(0.9, 0.99), - eps=1e-6, - weight_decay=0.01, - use_kahan_summation=True, - ), + reference_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.01, use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, - custom_tolerances={torch.bfloat16: ToleranceConfig(rtol=5e-2, atol=1e-2, equal_nan=False)}, + custom_tolerances={torch.bfloat16: Tolerance(rtol=5e-2, atol=1e-2, max_error_rate=0.01, equal_nan=False)}, ), - OptimizerTest( + Case( name="anyadam_kahan_decoupled_lr", - optimi_class=optimi.Adam, - optimi_params=AnyAdamParams( - lr=1e-3, - betas=(0.9, 0.99), - eps=1e-6, - weight_decay=1e-5, - decouple_lr=True, - kahan_sum=True, - ), + optimi_class=optimi.AdamW, + optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-5, decouple_lr=True, kahan_sum=True), reference_class=AnyPrecisionAdamW, - reference_params=AnyAdamParams( - lr=1e-3, - betas=(0.9, 0.99), - eps=1e-6, - weight_decay=1e-2, - use_kahan_summation=True, - ), + reference_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2, use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, - custom_tolerances={torch.bfloat16: ToleranceConfig(rtol=2e-2, atol=2e-3, equal_nan=False)}, + custom_tolerances={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01, equal_nan=False)}, ), ] - -# For compatibility with auto-generation system -BASE_TEST = ALL_TESTS[0] diff --git a/test_new/opt_lion.py b/test_new/opt_lion.py index 9d7d0e5..f9092c4 100644 --- a/test_new/opt_lion.py +++ b/test_new/opt_lion.py @@ -1,46 +1,34 @@ -"""Lion optimizer test definitions.""" +"""Lion optimizer tests in new Case format (manual list to match prior values).""" from dataclasses import dataclass import optimi from tests.reference import lion as reference_lion -from .framework import BaseParams, OptimizerTest +from .cases import BaseParams, Case @dataclass class LionParams(BaseParams): - """Type-safe Lion optimizer parameters.""" - betas: tuple[float, float] = (0.9, 0.99) -BASE_TEST = OptimizerTest( - name="lion_base", - 
optimi_class=optimi.Lion, - optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), - reference_class=reference_lion.Lion, - reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), -) - - -# Define all Adan test variants explicitly to match original tests -ALL_TESTS = [ - OptimizerTest( +TESTS = [ + Case( name="lion_base", optimi_class=optimi.Lion, optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), reference_class=reference_lion.Lion, reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), ), - OptimizerTest( + Case( name="lion_decoupled_wd", optimi_class=optimi.Lion, optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1, decouple_wd=True), reference_class=reference_lion.Lion, reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1), ), - OptimizerTest( + Case( name="lion_decoupled_lr", optimi_class=optimi.Lion, optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-5, decouple_lr=True), diff --git a/test_new/opt_radam.py b/test_new/opt_radam.py index 03f187d..07d44d7 100644 --- a/test_new/opt_radam.py +++ b/test_new/opt_radam.py @@ -1,4 +1,4 @@ -"""RAdam optimizer test definitions.""" +"""RAdam optimizer definitions using the new Case/variants flow.""" import inspect from dataclasses import dataclass, field @@ -6,30 +6,26 @@ import optimi import torch -from .framework import BaseParams, OptimizerTest, ToleranceConfig +from .cases import BaseParams, Case, Tolerance @dataclass class RAdamParams(BaseParams): - """Type-safe RAdam optimizer parameters.""" - betas: tuple[float, float] = (0.9, 0.99) eps: float = 1e-8 decoupled_weight_decay: bool = field(default=False) def __post_init__(self): - if self.decouple_wd: - self.decoupled_weight_decay = True - elif self.decouple_lr: + if self.decouple_wd or self.decouple_lr: self.decoupled_weight_decay = True -BASE_TEST = OptimizerTest( - name="radam_base", +BASE = Case( + name="radam", optimi_class=optimi.RAdam, optimi_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), reference_class=torch.optim.RAdam, reference_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), - custom_tolerances={torch.float32: ToleranceConfig(max_error_rate=0.001)}, + custom_tolerances={torch.float32: Tolerance(max_error_rate=0.001)}, test_decoupled_wd="decoupled_weight_decay" in inspect.signature(torch.optim.RAdam.__init__).parameters, ) diff --git a/test_new/opt_ranger.py b/test_new/opt_ranger.py index 89a8521..4c48d3c 100644 --- a/test_new/opt_ranger.py +++ b/test_new/opt_ranger.py @@ -1,34 +1,27 @@ -"""Ranger optimizer test definitions.""" +"""Ranger optimizer tests using new Case format (base only).""" from dataclasses import dataclass import optimi from tests import reference -from .framework import BaseParams, OptimizerTest +from .cases import BaseParams, Case @dataclass class RangerParams(BaseParams): - """Type-safe Ranger optimizer parameters.""" - betas: tuple[float, float] = (0.9, 0.99) eps: float = 1e-8 k: int = 6 # Lookahead steps alpha: float = 0.5 # Lookahead alpha -# Ranger only has base test - reference doesn't perform normal weight decay step -ALL_TESTS = [ - OptimizerTest( +TESTS = [ + Case( name="ranger_base", optimi_class=optimi.Ranger, optimi_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), reference_class=reference.Ranger, reference_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), - custom_iterations={"gradient_release": 160}, # Ranger needs longer testing due to 
lookahead step ) ] - -# Set BASE_TEST for auto-generation compatibility -BASE_TEST = ALL_TESTS[0] diff --git a/test_new/opt_sgd.py b/test_new/opt_sgd.py index 8845edd..518e8fb 100644 --- a/test_new/opt_sgd.py +++ b/test_new/opt_sgd.py @@ -1,4 +1,4 @@ -"""SGD optimizer test definitions with custom parameter handling.""" +"""SGD optimizer definitions using the new Case/variants flow (manual list).""" from dataclasses import dataclass from typing import Any @@ -7,69 +7,60 @@ import torch from tests import reference -from .framework import BaseParams, OptimizerTest +from .cases import BaseParams, Case @dataclass class SGDParams(BaseParams): - """Type-safe SGD optimizer parameters.""" - momentum: float = 0.0 dampening: bool = False # Optimi uses bool instead of float torch_init: bool = False def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: - """SGD needs special dampening conversion for reference optimizer.""" kwargs = super().to_reference_kwargs(reference_class) - # Convert dampening bool to float for reference optimizer if "dampening" in kwargs and isinstance(kwargs["dampening"], bool): kwargs["dampening"] = 0.9 if kwargs["dampening"] else 0.0 - return kwargs -# Define all SGD test variants explicitly -ALL_TESTS = [ - OptimizerTest( +# Manual list to mirror the original explicit coverage for SGD +TESTS = [ + Case( name="sgd_base", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0, dampening=False, weight_decay=0), reference_class=torch.optim.SGD, reference_params=SGDParams(lr=1e-3, momentum=0, dampening=0, weight_decay=0), - skip_tests=["accumulation"], # SGD base skips accumulation tests + skip_tests=["accumulation"], ), - OptimizerTest( + Case( name="sgd_momentum", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=0), reference_class=torch.optim.SGD, reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=0), ), - OptimizerTest( + Case( name="sgd_dampening", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=True, weight_decay=0, torch_init=True), reference_class=torch.optim.SGD, reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=0), ), - OptimizerTest( + Case( name="sgd_weight_decay", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=1e-2, decouple_wd=False), reference_class=torch.optim.SGD, reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=1e-2), - skip_tests=["accumulation"], # SGD with L2 weight decay skips accumulation tests + skip_tests=["accumulation"], ), - OptimizerTest( + Case( name="sgd_decoupled_lr", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=True, decouple_lr=True, weight_decay=1e-5, torch_init=True), reference_class=reference.DecoupledSGDW, reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=1e-5), - custom_iterations={"accumulation": 20}, # SGD uses fewer iterations for accumulation ), ] - -# Set BASE_TEST for auto-generation compatibility -BASE_TEST = ALL_TESTS[0] diff --git a/test_new/opt_stableadamw.py b/test_new/opt_stableadamw.py index 9a855e4..87db705 100644 --- a/test_new/opt_stableadamw.py +++ b/test_new/opt_stableadamw.py @@ -1,23 +1,21 @@ -"""StableAdamW optimizer test definitions.""" +"""StableAdamW optimizer definitions using new Case/variants flow.""" from dataclasses import dataclass import optimi from tests import reference -from .framework import 
BaseParams, OptimizerTest +from .cases import BaseParams, Case @dataclass class StableAdamWParams(BaseParams): - """Type-safe StableAdamW optimizer parameters.""" - betas: tuple[float, float] = (0.9, 0.999) eps: float = 1e-8 -BASE_TEST = OptimizerTest( - name="stableadamw_base", +BASE = Case( + name="stableadamw", optimi_class=optimi.StableAdamW, optimi_params=StableAdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), reference_class=reference.StableAdamWUnfused, diff --git a/test_new/optimizer_tests.py b/test_new/optimizer_tests.py deleted file mode 100644 index 6e139df..0000000 --- a/test_new/optimizer_tests.py +++ /dev/null @@ -1,282 +0,0 @@ -"""Automatic test discovery and generation system for optimizer tests. - -This module provides functionality to automatically discover optimizer test definitions -from test files and generate torch test variants, creating a comprehensive test registry. -""" - -import importlib -import warnings -from pathlib import Path -from typing import Any - -from .framework import OptimizerTest - - -def discover_optimizer_tests() -> list[OptimizerTest]: - """Automatically discover and generate tests from all test modules. - - Scans the test_new directory for Python files containing optimizer test definitions. - Each module should define either BASE_TEST (for auto-generation) or ALL_TESTS (custom). - - Returns: - List of all discovered and generated OptimizerTest instances. - """ - all_tests = [] - test_dir = Path(__file__).parent - test_files = [f for f in test_dir.glob("opt_*.py") if f.is_file()] - - for test_file in test_files: - module_name = test_file.stem - try: - # Import the test module using package-relative import - # __package__ will be 'optimi.test_new' when this file is imported as a package module - module = importlib.import_module(f".{module_name}", package=__package__) - - # Check for ALL_TESTS first (custom test definitions) - if hasattr(module, "ALL_TESTS"): - tests = getattr(module, "ALL_TESTS") - if isinstance(tests, list) and all(isinstance(t, OptimizerTest) for t in tests): - all_tests.extend(tests) - else: - warnings.warn(f"Module {module_name} has ALL_TESTS but it's not a list of OptimizerTest instances") - - # Check for BASE_TEST (for auto-generation) - elif hasattr(module, "BASE_TEST"): - base_test = getattr(module, "BASE_TEST") - if isinstance(base_test, OptimizerTest): - # Generate torch variants from base test - generated_tests = auto_generate_variants(base_test) - all_tests.extend(generated_tests) - else: - warnings.warn(f"Module {module_name} has BASE_TEST but it's not an OptimizerTest instance") - - else: - warnings.warn(f"Module {module_name} has neither BASE_TEST nor ALL_TESTS defined") - - except ImportError as e: - warnings.warn(f"Failed to import test module {module_name}: {e}") - except Exception as e: - warnings.warn(f"Error processing test module {module_name}: {e}") - - return all_tests - - -def auto_generate_variants(base_test: OptimizerTest) -> list[OptimizerTest]: - """Automatically generate torch test variants from a base test. - - Generates the following torch variants: - - base: Original test with no weight decay - - weight_decay: Test with weight decay enabled - - decoupled_wd: Test with decoupled weight decay - - decoupled_lr: Test with decoupled learning rate - - Args: - base_test: Base OptimizerTest to generate variants from. - - Returns: - List of generated test variants. 
- """ - variants = [] - - # Base variant (ensure weight_decay is 0) - base_params = _copy_params_with_overrides(base_test.optimi_params, weight_decay=0.0) - base_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=0.0) - - base_variant = OptimizerTest( - name=f"{base_test.optimizer_name}_base", - optimi_class=base_test.optimi_class, - optimi_params=base_params, - reference_class=base_test.reference_class, - reference_params=base_ref_params, - skip_tests=base_test.skip_tests.copy(), - only_dtypes=base_test.only_dtypes, - any_precision=base_test.any_precision, - custom_iterations=base_test.custom_iterations, - custom_tolerances=base_test.custom_tolerances, - ) - variants.append(base_variant) - - # L2 weight decay variant - if base_test.supports_l2_weight_decay(): - l2_params = _copy_params_with_overrides(base_test.optimi_params, weight_decay=0.01) - l2_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=0.01) - - l2_variant = OptimizerTest( - name=f"{base_test.optimizer_name}_l2_wd", - optimi_class=base_test.optimi_class, - optimi_params=l2_params, - reference_class=base_test.reference_class, - reference_params=l2_ref_params, - skip_tests=base_test.skip_tests.copy(), - only_dtypes=base_test.only_dtypes, - any_precision=base_test.any_precision, - custom_iterations=base_test.custom_iterations, - custom_tolerances=base_test.custom_tolerances, - ) - variants.append(l2_variant) - - # Decoupled weight decay variant - if base_test.test_decoupled_wd: - decoupled_wd_params = _copy_params_with_overrides(base_test.optimi_params, weight_decay=0.01, decouple_wd=True) - decoupled_wd_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=0.01, decouple_wd=True) - - decoupled_wd_variant = OptimizerTest( - name=f"{base_test.optimizer_name}_decoupled_wd", - optimi_class=base_test.optimi_class, - optimi_params=decoupled_wd_params, - reference_class=base_test.reference_class, - reference_params=decoupled_wd_ref_params, - skip_tests=base_test.skip_tests.copy(), - only_dtypes=base_test.only_dtypes, - any_precision=base_test.any_precision, - custom_iterations=base_test.custom_iterations, - custom_tolerances=base_test.custom_tolerances, - ) - variants.append(decoupled_wd_variant) - - # Decoupled learning rate variant - decoupled_lr_params = _copy_params_with_overrides(base_test.optimi_params, weight_decay=1e-5, decouple_lr=True) - if base_test.fully_decoupled_reference is not None: - decoupled_lr_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=1e-5, decouple_lr=True) - reference_class = base_test.fully_decoupled_reference - else: - decoupled_lr_ref_params = _copy_params_with_overrides(base_test.reference_params, weight_decay=0.01, decouple_lr=True) - reference_class = base_test.reference_class - - decoupled_lr_variant = OptimizerTest( - name=f"{base_test.optimizer_name}_decoupled_lr", - optimi_class=base_test.optimi_class, - optimi_params=decoupled_lr_params, - reference_class=reference_class, - reference_params=decoupled_lr_ref_params, - skip_tests=base_test.skip_tests.copy(), - only_dtypes=base_test.only_dtypes, - any_precision=base_test.any_precision, - custom_iterations=base_test.custom_iterations, - custom_tolerances=base_test.custom_tolerances, - ) - variants.append(decoupled_lr_variant) - - return variants - - -def _copy_params_with_overrides(params: Any, **overrides: Any) -> Any: - """Create a copy of parameter dataclass with specified overrides. 
- - Args: - params: Original parameter dataclass instance. - **overrides: Field values to override. - - Returns: - New parameter instance with overrides applied. - """ - # Get all current field values - current_values = {} - for field_info in params.__dataclass_fields__.values(): - current_values[field_info.name] = getattr(params, field_info.name) - - # Apply overrides - current_values.update(overrides) - - # Create new instance - return type(params)(**current_values) - - -# Central registry of all discovered tests -ALL_OPTIMIZER_TESTS: list[OptimizerTest] = [] - - -def _initialize_test_registry() -> None: - """Initialize the test registry by discovering all tests.""" - global ALL_OPTIMIZER_TESTS - if not ALL_OPTIMIZER_TESTS: # Only initialize once - ALL_OPTIMIZER_TESTS = discover_optimizer_tests() - - -def get_tests_by_optimizer(optimizer_name: str) -> list[OptimizerTest]: - """Get all tests for a specific optimizer. - - Args: - optimizer_name: Name of the optimizer (e.g., 'adam', 'sgd'). - - Returns: - List of OptimizerTest instances for the specified optimizer. - """ - _initialize_test_registry() - return [test for test in ALL_OPTIMIZER_TESTS if test.optimizer_name == optimizer_name] - - -def get_tests_by_variant(variant_name: str) -> list[OptimizerTest]: - """Get all tests for a specific variant across all optimizers. - - Args: - variant_name: Name of the variant (e.g., 'base', 'weight_decay'). - - Returns: - List of OptimizerTest instances for the specified variant. - """ - _initialize_test_registry() - return [test for test in ALL_OPTIMIZER_TESTS if test.variant_name == variant_name] - - -def get_test_by_name(name: str) -> OptimizerTest | None: - """Get a specific test by its full name. - - Args: - name: Full test name (e.g., 'adam_base', 'sgd_momentum'). - - Returns: - OptimizerTest instance if found, None otherwise. - """ - _initialize_test_registry() - for test in ALL_OPTIMIZER_TESTS: - if test.name == name: - return test - return None - - -def get_all_optimizer_names() -> list[str]: - """Get list of all available optimizer names. - - Returns: - Sorted list of unique optimizer names. - """ - _initialize_test_registry() - optimizer_names = {test.optimizer_name for test in ALL_OPTIMIZER_TESTS} - return sorted(optimizer_names) - - -def get_all_variant_names() -> list[str]: - """Get list of all available variant names. - - Returns: - Sorted list of unique variant names. - """ - _initialize_test_registry() - variant_names = {test.variant_name for test in ALL_OPTIMIZER_TESTS} - return sorted(variant_names) - - -def get_test_count() -> int: - """Get total number of discovered tests. - - Returns: - Total count of OptimizerTest instances. 
- """ - _initialize_test_registry() - return len(ALL_OPTIMIZER_TESTS) - - -def print_test_summary() -> None: - """Print a summary of discovered tests for debugging.""" - _initialize_test_registry() - - print(f"Discovered {len(ALL_OPTIMIZER_TESTS)} optimizer tests:") - print(f"Optimizers: {', '.join(get_all_optimizer_names())}") - print(f"Variants: {', '.join(get_all_variant_names())}") - - # Group by optimizer - for optimizer_name in get_all_optimizer_names(): - tests = get_tests_by_optimizer(optimizer_name) - test_names = [test.name for test in tests] - print(f" {optimizer_name}: {', '.join(test_names)}") diff --git a/test_new/pytest_integration.py b/test_new/pytest_integration.py deleted file mode 100644 index 2af5dda..0000000 --- a/test_new/pytest_integration.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Pytest integration functions for automatic mark generation. - -This module provides functions to create pytest parameters with automatic marks -for optimizers, devices, dtypes, and backends, enabling flexible test execution. -""" - -import pytest -import torch - -from .optimizer_tests import ALL_OPTIMIZER_TESTS, _initialize_test_registry - -_CACHED_TESTS: list | None = None - - -def create_marked_optimizer_tests(): - """Create optimizer test parameters with automatic marks. - - Each test gets marked with its optimizer name for targeted test execution. - - Returns: - list[pytest.param]: List of pytest parameters with optimizer marks. - """ - # Call discover_optimizer_tests directly to avoid global variable issues - from .optimizer_tests import discover_optimizer_tests - - global _CACHED_TESTS - if _CACHED_TESTS is None: - _CACHED_TESTS = discover_optimizer_tests() - return [pytest.param(test, marks=pytest.mark.__getattr__(test.optimizer_name), id=test.name) for test in _CACHED_TESTS] - - -def create_marked_device_types(): - """Create device type parameters with marks. - - Returns: - list[pytest.param]: List of device parameters with device marks. - """ - return [ - pytest.param("cpu", marks=pytest.mark.cpu, id="cpu"), - pytest.param("gpu", marks=pytest.mark.gpu, id="gpu"), - ] - - -def create_marked_dtypes(): - """Create dtype parameters with marks. - - Only includes float32 and bfloat16 as specified in requirements. - - Returns: - list[pytest.param]: List of dtype parameters with dtype marks. - """ - return [ - pytest.param(torch.float32, marks=pytest.mark.float32, id="float32"), - pytest.param(torch.bfloat16, marks=pytest.mark.bfloat16, id="bfloat16"), - ] - - -def create_marked_backends(): - """Create backend parameters with marks. - - Only includes torch and triton backends as specified in requirements. - - Returns: - list[pytest.param]: List of backend parameters with backend marks. - """ - return [ - pytest.param("torch", marks=pytest.mark.torch, id="torch"), - pytest.param("triton", marks=pytest.mark.triton, id="triton"), - ] - - -def create_gpu_only_device_types(): - """Create device type parameters for GPU-only tests. - - Returns: - list[pytest.param]: List containing only GPU device parameter. - """ - return [pytest.param("gpu", marks=pytest.mark.gpu, id="gpu")] - - -def create_float32_only_dtypes(): - """Create dtype parameters for float32-only tests. - - Returns: - list[pytest.param]: List containing only float32 dtype parameter. - """ - return [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")] - - -def get_optimizer_marks(): - """Get all available optimizer marks. - - Returns: - list[str]: List of optimizer names that can be used as pytest marks. 
- """ - _initialize_test_registry() - return sorted({test.optimizer_name for test in ALL_OPTIMIZER_TESTS}) - - -def get_device_marks(): - """Get all available device marks. - - Returns: - list[str]: List of device names that can be used as pytest marks. - """ - return ["cpu", "gpu"] - - -def get_dtype_marks(): - """Get all available dtype marks. - - Returns: - list[str]: List of dtype names that can be used as pytest marks. - """ - return ["float32", "bfloat16"] - - -def get_backend_marks(): - """Get all available backend marks. - - Returns: - list[str]: List of backend names that can be used as pytest marks. - """ - return ["torch", "triton"] - - -def create_test_matrix(): - """Create the complete test matrix for all combinations. - - Returns: - dict: Dictionary containing all parameter combinations for testing. - """ - return { - "optimizer_tests": create_marked_optimizer_tests(), - "device_types": create_marked_device_types(), - "dtypes": create_marked_dtypes(), - "backends": create_marked_backends(), - "gpu_only_devices": create_gpu_only_device_types(), - "float32_only_dtypes": create_float32_only_dtypes(), - } - - -def print_mark_summary(): - """Print a summary of available marks for debugging.""" - print("Available pytest marks:") - print(f" Optimizers: {', '.join(get_optimizer_marks())}") - print(f" Devices: {', '.join(get_device_marks())}") - print(f" Dtypes: {', '.join(get_dtype_marks())}") - print(f" Backends: {', '.join(get_backend_marks())}") - - print( - f"\nTotal test combinations: {len(create_marked_optimizer_tests())} optimizers × " - f"{len(get_device_marks())} devices × {len(get_dtype_marks())} dtypes × " - f"{len(get_backend_marks())} backends" - ) diff --git a/test_new/runners.py b/test_new/runners.py new file mode 100644 index 0000000..7862fb0 --- /dev/null +++ b/test_new/runners.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +import io +from typing import Optional + +import torch +from optimi import prepare_for_gradient_release, remove_gradient_release +from torch import Tensor + +from .cases import Case, Tolerance + + +def assert_most_approx_close( + a: torch.Tensor, + b: torch.Tensor, + rtol: float = 1e-3, + atol: float = 1e-3, + max_error_count: int = 0, + max_error_rate: float | None = None, + name: str = "", +) -> None: + """Assert that most values in two tensors are approximately close. + + Allows for a small number of errors based on max_error_count and max_error_rate. 
+ """ + idx = torch.isclose(a.float(), b.float(), rtol=rtol, atol=atol) + error_count = (idx == 0).sum().item() + + if max_error_rate is not None: + if error_count > (a.numel()) * max_error_rate and error_count > max_error_count: + print(f"{name}Too many values not close: assert {error_count} < {(a.numel()) * max_error_rate}") + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) + elif error_count > max_error_count: + print(f"{name}Too many values not close: assert {error_count} < {max_error_count}") + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) + + +class MLP(torch.nn.Module): + def __init__(self, input_size: int, hidden_size: int, device: torch.device, dtype: torch.dtype): + super().__init__() + self.norm = torch.nn.LayerNorm(input_size, device=device, dtype=dtype) + self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=False, device=device, dtype=dtype) + self.act = torch.nn.Mish() + self.fc2 = torch.nn.Linear(hidden_size, 1, bias=False, device=device, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + x = self.norm(x) + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +def run_correctness( + case: Case, + device: torch.device, + dtype: torch.dtype, + backend: str, + dims: tuple[int, int] | None = None, +) -> None: + # iterations: default parity CPU=20, GPU=40 unless overridden + iterations = case.custom_iterations.get("correctness") if case.custom_iterations else None + if iterations is None: + iterations = 20 if device.type == "cpu" else 40 + # Adan bfloat16 updates are noisier on GPU: align with tests which use 20 iters + if device.type != "cpu" and dtype == torch.bfloat16 and case.optimizer_name == "adan": + iterations = 20 + tolerance = case.custom_tolerances[dtype] + + # Dimensions and error counts + if dims is not None: + dim1, dim2 = dims + elif device.type == "cpu": + dim1, dim2 = 64, 128 + else: + dim1, dim2 = 256, 512 + + batch_size = 1 if device.type == "cpu" else 32 + max_error_count = 2 if device.type == "cpu" else 5 + + # bfloat16 error rate (kept for parity; not used directly here) + _max_error_rate: Optional[float] = 0.01 if dtype == torch.bfloat16 else None + + # Create models + m1 = MLP(dim1, dim2, device=device, dtype=dtype) + m2 = MLP(dim1, dim2, device=device, dtype=dtype) + m2.load_state_dict(m1.state_dict()) + + # Convert parameters to float for non-any_precision + if not case.any_precision and dtype != torch.float32: + for p in m1.parameters(): + p.data = p.data.float() + + # Optimizers + reference_class = case.reference_class + reference_kwargs = case.to_reference_kwargs() + optimi_kwargs = case.to_optimi_kwargs(backend) + reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) + optimi_optimizer = case.optimi_class(m2.parameters(), **optimi_kwargs) + + buffer = io.BytesIO() + + for i in range(iterations): + input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) + input2 = input1.detach().clone() + target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) + target2 = target1.detach().clone() + + if not case.any_precision and dtype != torch.float32: + input1 = input1.float() + target1 = target1.float() + + output1 = m1(input1) + output2 = m2(input2) + loss1 = torch.nn.functional.mse_loss(output1, target1) + loss2 = torch.nn.functional.mse_loss(output2, target2) + loss1.backward() + loss2.backward() + + reference_optimizer.step() + optimi_optimizer.step() + reference_optimizer.zero_grad() + optimi_optimizer.zero_grad() + + assert_most_approx_close( + 
m1.fc1.weight, + m2.fc1.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=tolerance.max_error_rate, + name="fc1: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=tolerance.max_error_rate, + name="fc2: ", + ) + + # state_dict save/load periodically + if i % max(1, iterations // 10) == 0 and i > 0: + torch.save(optimi_optimizer.state_dict(), buffer) + buffer.seek(0) + ckpt = torch.load(buffer, weights_only=True) + optimi_optimizer = case.optimi_class(m2.parameters(), **optimi_kwargs) + optimi_optimizer.load_state_dict(ckpt) + buffer.seek(0) + buffer.truncate(0) + + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=tolerance.max_error_rate, + name="fc1 after load: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=tolerance.max_error_rate, + name="fc2 after load: ", + ) + + +def run_gradient_release( + case: Case, + device: torch.device, + dtype: torch.dtype, + backend: str, + dims: tuple[int, int] | None = None, +) -> None: + def optimizer_hook(parameter) -> None: + torch_optimizers[parameter].step() + torch_optimizers[parameter].zero_grad() + + # iterations: default parity GPU=40 unless overridden + iterations = case.custom_iterations.get("gradient_release") if case.custom_iterations else None + if iterations is None: + iterations = 40 + tolerance = case.custom_tolerances[dtype] + # Enforce a minimal baseline tolerance for gradient-release parity (from parity notes) + if dtype == torch.float32: + baseline = Tolerance(atol=1e-6, rtol=1e-5, max_error_rate=5e-4) + elif dtype == torch.bfloat16: + baseline = Tolerance(atol=1e-3, rtol=1e-2, max_error_rate=0.01) + elif dtype == torch.float16: + baseline = Tolerance(atol=1e-4, rtol=1e-3, max_error_rate=0.01) + else: + baseline = tolerance + tolerance = Tolerance( + rtol=max(tolerance.rtol, baseline.rtol), + atol=max(tolerance.atol, baseline.atol), + max_error_rate=max(tolerance.max_error_rate, baseline.max_error_rate), + equal_nan=tolerance.equal_nan, + ) + + max_error_count = 12 # more lenient for noisy updates + + # Dims: default 128x256 unless provided (tests also use 128x1024) + dim1, dim2 = dims if dims is not None else (128, 256) + batch_size = 32 + + m1 = MLP(dim1, dim2, device=device, dtype=dtype) # regular + m2 = MLP(dim1, dim2, device=device, dtype=dtype) # PyTorch hooks + m3 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi gradient release + m2.load_state_dict(m1.state_dict()) + m3.load_state_dict(m1.state_dict()) + + reference_class = case.reference_class + reference_kwargs = case.to_reference_kwargs() + optimi_kwargs = case.to_optimi_kwargs(backend) + + regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) + torch_optimizers = {p: reference_class([p], **reference_kwargs) for p in m2.parameters()} + pytorch_hooks = [] + for p in m2.parameters(): + pytorch_hooks.append(p.register_post_accumulate_grad_hook(optimizer_hook)) + + optimi_kwargs["gradient_release"] = True + optimi_optimizer = case.optimi_class(m3.parameters(), **optimi_kwargs) + prepare_for_gradient_release(m3, optimi_optimizer) + + for _ in range(iterations): + input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) + input2 = input1.clone() + input3 
= input1.clone() + target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) + target2 = target1.clone() + target3 = target1.clone() + + output1 = m1(input1) + output2 = m2(input2) + output3 = m3(input3) + + loss1 = torch.nn.functional.mse_loss(output1, target1) + loss2 = torch.nn.functional.mse_loss(output2, target2) + loss3 = torch.nn.functional.mse_loss(output3, target3) + + loss1.backward() + loss2.backward() + loss3.backward() + + regular_optimizer.step() + regular_optimizer.zero_grad() + + # Optional framework step for determinism left disabled + # optimi_optimizer.step(); optimi_optimizer.zero_grad() + + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=tolerance.max_error_rate, + name="PyTorch-PyTorch: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=tolerance.max_error_rate, + name="PyTorch-PyTorch: ", + ) + assert_most_approx_close( + m1.fc1.weight, + m3.fc1.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=tolerance.max_error_rate, + name="PyTorch-Optimi: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m3.fc2.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=tolerance.max_error_rate, + name="PyTorch-Optimi: ", + ) + + for h in pytorch_hooks: + h.remove() + remove_gradient_release(m3) + + +def run_accumulation( + case: Case, + device: torch.device, + dtype: torch.dtype, + backend: str, + dims: tuple[int, int] | None = None, +) -> None: + # iterations: default parity GPU=40 unless overridden + iterations = case.custom_iterations.get("accumulation") if case.custom_iterations else None + if iterations is None: + iterations = 40 + + max_error_rate = 0.035 + tolerance = Tolerance(rtol=1e-2, atol=1e-2) + + # Dims: default 128x256 unless provided (tests also use 128x1024) + dim1, dim2 = dims if dims is not None else (128, 256) + batch_size = 32 + + m1 = MLP(dim1, dim2, device=device, dtype=dtype) # Regular optimizer + m2 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi accumulation + m2.load_state_dict(m1.state_dict()) + + reference_class = case.reference_class + reference_kwargs = case.to_reference_kwargs() + optimi_kwargs = case.to_optimi_kwargs(backend) + + regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) + optimi_kwargs["gradient_release"] = True + optimi_optimizer = case.optimi_class(m2.parameters(), **optimi_kwargs) + prepare_for_gradient_release(m2, optimi_optimizer) + + gradient_accumulation_steps = 4 + + for i in range(iterations): + input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) + input2 = input1.clone() + target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) + target2 = target1.clone() + + optimi_optimizer.optimizer_accumulation = (i + 1) % gradient_accumulation_steps != 0 + + output1 = m1(input1) + output2 = m2(input2) + loss1 = torch.nn.functional.mse_loss(output1, target1) + loss2 = torch.nn.functional.mse_loss(output2, target2) + + loss1.backward() + loss2.backward() + + if not optimi_optimizer.optimizer_accumulation: + regular_optimizer.step() + regular_optimizer.zero_grad() + + # Optional framework step left disabled to mirror prior behavior + # optimi_optimizer.step(); optimi_optimizer.zero_grad() + + assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, 
rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) + assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) + + remove_gradient_release(m2) diff --git a/test_new/test_optimizers.py b/test_new/test_optimizers.py deleted file mode 100644 index a35dccb..0000000 --- a/test_new/test_optimizers.py +++ /dev/null @@ -1,556 +0,0 @@ -"""Main test class with comprehensive test methods for the unified optimizer test framework. - -This module provides the TestOptimizers class that implements all test types: -- test_optimizer_correctness: Validates optimizer correctness against reference implementations -- test_gradient_release: Tests gradient release functionality (GPU-only) -- test_optimizer_accumulation: Tests optimizer accumulation functionality (GPU-only) -""" - -import io - -import pytest -import torch -from optimi import prepare_for_gradient_release, remove_gradient_release -from optimi.utils import MIN_TORCH_2_6 -from torch import Tensor - -from .framework import OptimizerTest, ToleranceConfig, assert_most_approx_close -from .pytest_integration import ( - create_float32_only_dtypes, - create_gpu_only_device_types, - create_marked_backends, - create_marked_device_types, - create_marked_dtypes, - create_marked_optimizer_tests, -) - - -class MLP(torch.nn.Module): - """Simple MLP model for testing optimizer behavior.""" - - def __init__(self, input_size: int, hidden_size: int, device: torch.device, dtype: torch.dtype): - super().__init__() - self.norm = torch.nn.LayerNorm(input_size, device=device, dtype=dtype) - self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=False, device=device, dtype=dtype) - self.act = torch.nn.Mish() - self.fc2 = torch.nn.Linear(hidden_size, 1, bias=False, device=device, dtype=dtype) - - def forward(self, x: Tensor) -> Tensor: - x = self.norm(x) - x = self.fc1(x) - x = self.act(x) - x = self.fc2(x) - return x - - -class TestOptimizers: - """Main test class for comprehensive optimizer testing.""" - - @pytest.mark.parametrize("optimizer_test", create_marked_optimizer_tests()) - @pytest.mark.parametrize("device_type", create_marked_device_types()) - @pytest.mark.parametrize("dtype", create_marked_dtypes()) - @pytest.mark.parametrize("backend", create_marked_backends()) - def test_optimizer_correctness( - self, - optimizer_test: OptimizerTest, - device_type: str, - dtype: torch.dtype, - backend: str, - gpu_device: str, - ) -> None: - """Test optimizer correctness against reference implementation. - - Validates that the optimi optimizer produces results consistent with - the reference PyTorch optimizer across different configurations. 
- """ - # Skip test if conditions don't match - if self._should_skip("correctness", optimizer_test, device_type, dtype, backend): - pytest.skip(f"Skipping {optimizer_test.name} correctness test for {device_type}/{dtype}/{backend}") - - # Log a random seed for reproducibility while keeping randomness - seed = int(torch.randint(0, 2**31 - 1, (1,)).item()) - print(f"[seed] correctness: {optimizer_test.name} {device_type}/{dtype}/{backend} -> {seed}") - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - # Determine actual device - device = torch.device(gpu_device if device_type == "gpu" else "cpu") - - # Run the correctness test - self._run_correctness_test(optimizer_test, device, dtype, backend) - - @pytest.mark.parametrize("optimizer_test", create_marked_optimizer_tests()) - @pytest.mark.parametrize("device_type", create_gpu_only_device_types()) - @pytest.mark.parametrize("dtype", create_float32_only_dtypes()) - @pytest.mark.parametrize("backend", create_marked_backends()) - def test_gradient_release( - self, - optimizer_test: OptimizerTest, - device_type: str, - dtype: torch.dtype, - backend: str, - gpu_device: str, - ) -> None: - """Test gradient release functionality (GPU only). - - Validates that gradient release produces consistent results with - standard optimizer behavior while freeing memory during backprop. - """ - # Skip test if conditions don't match - if self._should_skip("gradient_release", optimizer_test, device_type, dtype, backend): - pytest.skip(f"Skipping {optimizer_test.name} gradient_release test for {device_type}/{dtype}/{backend}") - - # Log a random seed for reproducibility while keeping randomness - seed = int(torch.randint(0, 2**31 - 1, (1,)).item()) - print(f"[seed] gradient_release: {optimizer_test.name} {device_type}/{dtype}/{backend} -> {seed}") - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - # Determine actual device (always GPU for this test) - device = torch.device(gpu_device) - - # Run the gradient release test - self._run_gradient_release_test(optimizer_test, device, dtype, backend) - - @pytest.mark.parametrize("optimizer_test", create_marked_optimizer_tests()) - @pytest.mark.parametrize("device_type", create_gpu_only_device_types()) - @pytest.mark.parametrize("dtype", create_float32_only_dtypes()) - @pytest.mark.parametrize("backend", create_marked_backends()) - def test_optimizer_accumulation( - self, - optimizer_test: OptimizerTest, - device_type: str, - dtype: torch.dtype, - backend: str, - gpu_device: str, - ) -> None: - """Test optimizer accumulation functionality (GPU only). - - Validates that optimizer accumulation produces results consistent - with gradient accumulation while being more memory efficient. 
- """ - # Skip test if conditions don't match - if self._should_skip("accumulation", optimizer_test, device_type, dtype, backend): - pytest.skip(f"Skipping {optimizer_test.name} accumulation test for {device_type}/{dtype}/{backend}") - - # Determine actual device (always GPU for this test) - device = torch.device(gpu_device) - - # Run the accumulation test - self._run_accumulation_test(optimizer_test, device, dtype, backend) - - def _prepare_kwargs(self, optimizer_test: OptimizerTest, backend: str) -> tuple[dict, dict]: - """Prepare reference and optimi kwargs including backend-specific flags.""" - reference_kwargs = optimizer_test.to_reference_kwargs() - optimi_kwargs = optimizer_test.to_optimi_kwargs() - if backend == "triton": - optimi_kwargs["triton"] = True - else: - optimi_kwargs["foreach"] = False - optimi_kwargs["triton"] = False - return reference_kwargs, optimi_kwargs - - def _should_skip( - self, - test_type: str, - optimizer_test: OptimizerTest, - device_type: str, - dtype: torch.dtype, - backend: str, - ) -> bool: - """Comprehensive test skipping logic with all conditions. - - Args: - test_type: Type of test ('correctness', 'gradient_release', 'accumulation') - optimizer_test: The optimizer test configuration - device_type: Device type ('cpu' or 'gpu') - dtype: Data type (torch.float32 or torch.bfloat16) - backend: Backend type ('torch' or 'triton') - - Returns: - True if test should be skipped, False otherwise - """ - # Check if test type is explicitly skipped - if optimizer_test.should_skip_test(test_type): - return True - - # Respect per-test dtype constraints if provided - if getattr(optimizer_test, "only_dtypes", None): - if dtype not in optimizer_test.only_dtypes: - return True - - # Skip triton tests on CPU - if backend == "triton" and device_type == "cpu": - return True - - # Skip triton tests if PyTorch version is too old - if backend == "triton" and not MIN_TORCH_2_6: - return True - - # Skip GPU tests if no GPU is available - if device_type == "gpu" and not (torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())): - return True - - # Gradient release and accumulation are GPU-only tests - if test_type in ["gradient_release", "accumulation"] and device_type == "cpu": - return True - - # Skip bfloat16 on CPU for most optimizers (matches original test behavior) - # Only anyadam tests bfloat16 on CPU for precision testing - if device_type == "cpu" and dtype == torch.bfloat16 and not optimizer_test.name.startswith("anyadam"): - return True - - return False - - def _run_correctness_test( - self, - optimizer_test: OptimizerTest, - device: torch.device, - dtype: torch.dtype, - backend: str, - ) -> None: - """Core correctness test implementation. - - Creates two identical models, runs them with optimi and reference optimizers, - and validates that they produce consistent results. 
- """ - # Get test configuration - iterations = optimizer_test.get_iterations_for_test("correctness") - tolerance = optimizer_test.get_tolerance_for_dtype(dtype) - - # Determine model dimensions and error handling based on device - if device.type == "cpu": - dim1, dim2 = 64, 128 - batch_size = 1 - max_error_count = 2 - else: - dim1, dim2 = 256, 512 - batch_size = 32 - max_error_count = 5 - - # Set max_error_rate for bfloat16 like original tests - max_error_rate = None - if dtype == torch.bfloat16: - max_error_rate = 0.01 # Allow 1% of values to be outside tolerance - - # Skip 1x1 tests - if dim1 == 1 and dim2 == 1: - pytest.skip("Skipping 1x1 optimizer test") - - # Create models - m1 = MLP(dim1, dim2, device=device, dtype=dtype) - m2 = MLP(dim1, dim2, device=device, dtype=dtype) - m2.load_state_dict(m1.state_dict()) - - # Convert model parameters to float for non-any_precision testing - if not optimizer_test.any_precision and dtype != torch.float32: - for p in m1.parameters(): - p.data = p.data.float() - - # Create optimizers - reference_class = optimizer_test.reference_class - reference_kwargs, optimi_kwargs = self._prepare_kwargs(optimizer_test, backend) - - reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) - optimi_optimizer = optimizer_test.optimi_class(m2.parameters(), **optimi_kwargs) - - # Training loop with state dict testing - buffer = io.BytesIO() - - for i in range(iterations): - # Generate training data - input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) - input2 = input1.detach().clone() - target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) - target2 = target1.detach().clone() - - # Convert inputs to float for non-any_precision testing - if not optimizer_test.any_precision and dtype != torch.float32: - input1 = input1.float() - target1 = target1.float() - - # Forward pass - output1 = m1(input1) - output2 = m2(input2) - - # Loss calculation - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - - # Backward pass - loss1.backward() - loss2.backward() - - # Optimizer step - reference_optimizer.step() - optimi_optimizer.step() - - # Zero gradients - reference_optimizer.zero_grad() - optimi_optimizer.zero_grad() - - # Compare model weights - assert_most_approx_close( - m1.fc1.weight, - m2.fc1.weight, - atol=tolerance.atol, - rtol=tolerance.rtol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="fc1: ", - ) - assert_most_approx_close( - m1.fc2.weight, - m2.fc2.weight, - atol=tolerance.atol, - rtol=tolerance.rtol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="fc2: ", - ) - - # Test state_dict saving and loading periodically - if i % max(1, iterations // 10) == 0 and i > 0: - # Save optimizer state - torch.save(optimi_optimizer.state_dict(), buffer) - buffer.seek(0) - # Load checkpoint - ckpt = torch.load(buffer, weights_only=True) - # Recreate optimizer and load its state - optimi_optimizer = optimizer_test.optimi_class(m2.parameters(), **optimi_kwargs) - optimi_optimizer.load_state_dict(ckpt) - # Clear buffer - buffer.seek(0) - buffer.truncate(0) - - # Verify models are still aligned after state_dict loading - assert_most_approx_close( - m1.fc1.weight, - m2.fc1.weight, - atol=tolerance.atol, - rtol=tolerance.rtol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="fc1 after load: ", - ) - assert_most_approx_close( - m1.fc2.weight, - m2.fc2.weight, - 
atol=tolerance.atol, - rtol=tolerance.rtol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="fc2 after load: ", - ) - - def _run_gradient_release_test( - self, - optimizer_test: OptimizerTest, - device: torch.device, - dtype: torch.dtype, - backend: str, - ) -> None: - """Core gradient release test implementation. - - Compares gradient release behavior with torch PyTorch optimizer hooks - and regular optimizer behavior to ensure consistency. - """ - - def optimizer_hook(parameter) -> None: - torch_optimizers[parameter].step() - torch_optimizers[parameter].zero_grad() - - # Get test configuration - iterations = optimizer_test.get_iterations_for_test("gradient_release") - tolerance = optimizer_test.get_tolerance_for_dtype(dtype) - - # Set default tolerances for gradient release (slightly more lenient) - if optimizer_test.custom_tolerances is None: - if dtype == torch.float32: - tolerance = ToleranceConfig(rtol=1e-5, atol=2e-6) - elif dtype == torch.bfloat16: - tolerance = ToleranceConfig(rtol=1e-2, atol=2e-3) - - # Since Lion & Adan can have noisy updates, allow up to 12 errors - max_error_count = 12 - - # Model dimensions for gradient release tests - dim1, dim2 = 128, 256 - batch_size = 32 - - # Create three identical models - m1 = MLP(dim1, dim2, device=device, dtype=dtype) # Regular optimizer - m2 = MLP(dim1, dim2, device=device, dtype=dtype) # PyTorch hooks - m3 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi gradient release - m2.load_state_dict(m1.state_dict()) - m3.load_state_dict(m1.state_dict()) - - # Create optimizers - reference_class = optimizer_test.reference_class - reference_kwargs, optimi_kwargs = self._prepare_kwargs(optimizer_test, backend) - - # Regular optimizer - regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) - - # PyTorch Method: taken from https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - torch_optimizers = {p: reference_class([p], **reference_kwargs) for p in m2.parameters()} - - pytorch_hooks = [] - for p in m2.parameters(): - pytorch_hooks.append(p.register_post_accumulate_grad_hook(optimizer_hook)) - - # Optimi Method with gradient release - optimi_kwargs["gradient_release"] = True - optimi_optimizer = optimizer_test.optimi_class(m3.parameters(), **optimi_kwargs) - prepare_for_gradient_release(m3, optimi_optimizer) - - # Training loop - for i in range(iterations): - input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) - input2 = input1.clone() - input3 = input1.clone() - target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) - target2 = target1.clone() - target3 = target1.clone() - - output1 = m1(input1) - output2 = m2(input2) - output3 = m3(input3) - - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - loss3 = torch.nn.functional.mse_loss(output3, target3) - - loss1.backward() - loss2.backward() - loss3.backward() - - regular_optimizer.step() - regular_optimizer.zero_grad() - - # Simulate framework optimizer step (randomly enabled) - framework_opt_step = torch.rand(1).item() > 0.5 - if framework_opt_step: - optimi_optimizer.step() - optimi_optimizer.zero_grad() - - # Compare results - assert_most_approx_close( - m1.fc1.weight, - m2.fc1.weight, - rtol=tolerance.rtol, - atol=tolerance.atol, - max_error_count=max_error_count, - name="PyTorch-PyTorch: ", - ) - assert_most_approx_close( - m1.fc2.weight, - m2.fc2.weight, - rtol=tolerance.rtol, - atol=tolerance.atol, - 
max_error_count=max_error_count, - name="PyTorch-PyTorch: ", - ) - assert_most_approx_close( - m1.fc1.weight, - m3.fc1.weight, - rtol=tolerance.rtol, - atol=tolerance.atol, - max_error_count=max_error_count, - name="PyTorch-Optimi: ", - ) - assert_most_approx_close( - m1.fc2.weight, - m3.fc2.weight, - rtol=tolerance.rtol, - atol=tolerance.atol, - max_error_count=max_error_count, - name="PyTorch-Optimi: ", - ) - - # Cleanup - for h in pytorch_hooks: - h.remove() - remove_gradient_release(m3) - - def _run_accumulation_test( - self, - optimizer_test: OptimizerTest, - device: torch.device, - dtype: torch.dtype, - backend: str, - ) -> None: - """Core accumulation test implementation. - - Tests optimizer accumulation functionality which approximates gradient - accumulation by accumulating directly into optimizer states. - """ - # Get test configuration - iterations = optimizer_test.get_iterations_for_test("accumulation") - - # Since optimizer accumulation approximates gradient accumulation, - # the tolerances are high despite the low number of iterations - max_error_rate = 0.035 - tolerance = ToleranceConfig(rtol=1e-2, atol=1e-2) - - # Model dimensions for accumulation tests - dim1, dim2 = 128, 256 - batch_size = 32 - - # Create two identical models - m1 = MLP(dim1, dim2, device=device, dtype=dtype) # Regular optimizer - m2 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi accumulation - m2.load_state_dict(m1.state_dict()) - - # Create optimizers - reference_class = optimizer_test.reference_class - reference_kwargs, optimi_kwargs = self._prepare_kwargs(optimizer_test, backend) - - # Regular optimizer - regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) - - # Optimi optimizer with gradient release for accumulation - optimi_kwargs["gradient_release"] = True - optimi_optimizer = optimizer_test.optimi_class(m2.parameters(), **optimi_kwargs) - prepare_for_gradient_release(m2, optimi_optimizer) - - gradient_accumulation_steps = 4 - - # Training loop - for i in range(iterations): - input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) - input2 = input1.clone() - target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) - target2 = target1.clone() - - # Set accumulation mode - optimi_optimizer.optimizer_accumulation = (i + 1) % gradient_accumulation_steps != 0 - - output1 = m1(input1) - output2 = m2(input2) - - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - - loss1.backward() - loss2.backward() - - # Only step regular optimizer when not accumulating - if not optimi_optimizer.optimizer_accumulation: - regular_optimizer.step() - regular_optimizer.zero_grad() - - # Simulate framework optimizer step (randomly enabled) - framework_opt_step = torch.rand(1).item() > 0.5 - if framework_opt_step: - optimi_optimizer.step() - optimi_optimizer.zero_grad() - - # Unlike other tests, compare that the weights are in the same approximate range at the end of training - assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) - assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) - - # Cleanup - remove_gradient_release(m2) diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py new file mode 100644 index 0000000..9aa8b06 --- /dev/null +++ b/test_new/test_optimizers_unified.py @@ -0,0 +1,114 @@ +import pytest +import torch +from 
optimi.utils import MIN_TORCH_2_6 + +from .cases import discover_cases +from .runners import run_accumulation, run_correctness, run_gradient_release + +CASES = tuple(discover_cases()) + +DEVICE_PARAMS = [ + pytest.param("cpu", marks=pytest.mark.cpu, id="cpu"), + pytest.param("gpu", marks=pytest.mark.gpu, id="gpu"), +] +DTYPE_PARAMS = [ + pytest.param(torch.float32, marks=pytest.mark.float32, id="float32"), + pytest.param(torch.bfloat16, marks=pytest.mark.bfloat16, id="bfloat16"), +] +BACKEND_PARAMS = [ + pytest.param("torch", marks=pytest.mark.torch, id="torch"), + pytest.param("triton", marks=pytest.mark.triton, id="triton"), +] + +# Attach per-optimizer marks so users can -m adam, -m sgd, etc. +OPTIM_PARAMS = [pytest.param(c, id=c.name, marks=getattr(pytest.mark, c.optimizer_name)) for c in CASES] + +# Dimension parameter spaces (match legacy tests) +# Correctness dims: CPU -> (64,64), (64,128); GPU -> (256,256), (256,512), (256,1024), (256,2048) +CORRECTNESS_DIMS = [ + pytest.param(("cpu", (64, 64)), id="cpu-64x64"), + pytest.param(("cpu", (64, 128)), id="cpu-64x128"), + pytest.param(("gpu", (256, 256)), id="gpu-256x256"), + pytest.param(("gpu", (256, 512)), id="gpu-256x512"), + pytest.param(("gpu", (256, 1024)), id="gpu-256x1024"), + pytest.param(("gpu", (256, 2048)), id="gpu-256x2048"), +] + +# Gradient release and accumulation dims (GPU-only): (128,256) and (128,1024) +GR_DIMS = [ + pytest.param((128, 256), id="gr-128x256"), + pytest.param((128, 1024), id="gr-128x1024"), +] + + +def _should_skip(test_type, case, device_type, dtype, backend) -> bool: + # Explicit per-case skip + if test_type in set(case.skip_tests): + return True + + # Respect per-test dtype constraints if provided + if case.only_dtypes and dtype not in case.only_dtypes: + return True + + # Skip triton on CPU + if backend == "triton" and device_type == "cpu": + return True + + # Triton requires torch >= 2.6 + if backend == "triton" and not MIN_TORCH_2_6: + return True + + # GPU availability + if device_type == "gpu" and not ( + torch.cuda.is_available() + or (hasattr(torch, "xpu") and torch.xpu.is_available()) + or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) + ): + return True + + # Gradient release and accumulation are GPU-only tests + if test_type in ["gradient_release", "accumulation"] and device_type == "cpu": + return True + + # Skip bfloat16 on CPU for most optimizers; allow anyadam-style exceptions via case.any_precision if needed + if device_type == "cpu" and dtype == torch.bfloat16 and not case.any_precision: + return True + + return False + + +@pytest.mark.parametrize("case", OPTIM_PARAMS) +@pytest.mark.parametrize("device_type", DEVICE_PARAMS) +@pytest.mark.parametrize("dtype", DTYPE_PARAMS) +@pytest.mark.parametrize("backend", BACKEND_PARAMS) +@pytest.mark.parametrize("dims_spec", CORRECTNESS_DIMS) +def test_correctness(case, device_type, dtype, backend, dims_spec, gpu_device): + if _should_skip("correctness", case, device_type, dtype, backend): + pytest.skip() + dims_device, dims = dims_spec + if dims_device != device_type: + pytest.skip() + device = torch.device(gpu_device if device_type == "gpu" else "cpu") + run_correctness(case, device, dtype, backend, dims=dims) + + +@pytest.mark.parametrize("case", OPTIM_PARAMS) +@pytest.mark.parametrize("device_type", [pytest.param("gpu", marks=pytest.mark.gpu, id="gpu")]) +@pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) +@pytest.mark.parametrize("backend", BACKEND_PARAMS) 
+@pytest.mark.parametrize("dims", GR_DIMS) +def test_gradient_release(case, device_type, dtype, backend, dims, gpu_device): + if _should_skip("gradient_release", case, device_type, dtype, backend): + pytest.skip() + run_gradient_release(case, torch.device(gpu_device), dtype, backend, dims=dims) + + +@pytest.mark.parametrize("case", OPTIM_PARAMS) +@pytest.mark.parametrize("device_type", [pytest.param("gpu", marks=pytest.mark.gpu, id="gpu")]) +@pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) +@pytest.mark.parametrize("backend", BACKEND_PARAMS) +@pytest.mark.parametrize("dims", GR_DIMS) +def test_accumulation(case, device_type, dtype, backend, dims, gpu_device): + if _should_skip("accumulation", case, device_type, dtype, backend): + pytest.skip() + run_accumulation(case, torch.device(gpu_device), dtype, backend, dims=dims) From e6c178dafc4847295b435cd16b7755597d3da89c Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sat, 20 Dec 2025 14:32:41 -0600 Subject: [PATCH 05/16] iterate on test_new --- test_new/cases.py | 58 ++++++++---- test_new/config.py | 61 ++++++++++++ test_new/opt_adam.py | 2 +- test_new/opt_ranger.py | 4 +- test_new/opt_sgd.py | 6 +- test_new/opt_stableadamw.py | 2 +- test_new/runners.py | 140 ++++++++++++++++------------ test_new/test_optimizers_unified.py | 61 +++++++----- 8 files changed, 224 insertions(+), 110 deletions(-) create mode 100644 test_new/config.py diff --git a/test_new/cases.py b/test_new/cases.py index 4ac6a19..c800c87 100644 --- a/test_new/cases.py +++ b/test_new/cases.py @@ -1,7 +1,7 @@ -from __future__ import annotations - import importlib -from dataclasses import dataclass, field, replace +import inspect +from dataclasses import asdict, dataclass, field, replace +from enum import Enum from pathlib import Path from typing import Any @@ -10,6 +10,23 @@ from torch.optim import Optimizer +class TestType(Enum): + correctness = "correctness" + gradient_release = "gradient_release" + accumulation = "accumulation" + + +class DeviceType(Enum): + cpu = "cpu" + gpu = "gpu" + + +class Backend(Enum): + torch = "torch" + triton = "triton" + foreach = "foreach" + + @dataclass class Tolerance: atol: float = 1e-6 @@ -30,9 +47,6 @@ def with_(self, **overrides: Any) -> "BaseParams": return replace(self, **overrides) def _kwargs_for(self, cls: type | None) -> dict[str, Any]: - import inspect - from dataclasses import asdict - if cls is None: return {} sig = inspect.signature(cls.__init__) @@ -62,9 +76,9 @@ class Case: # Behavior / constraints test_decoupled_wd: bool = True - skip_tests: list[str] = field(default_factory=list) + skip_tests: list[TestType] = field(default_factory=list) any_precision: bool = False - custom_iterations: dict[str, int] | None = None + custom_iterations: dict[TestType | tuple[TestType, DeviceType], int] | None = None custom_tolerances: dict[torch.dtype, Tolerance] | None = None only_dtypes: list[torch.dtype] | None = None @@ -86,32 +100,40 @@ def optimizer_name(self) -> str: def variant_name(self) -> str: return self.name.split("_", 1)[1] if "_" in self.name else "base" - def to_optimi_kwargs(self, backend: str | None = None) -> dict[str, Any]: + def to_optimi_kwargs(self, backend: Backend | None = None) -> dict[str, Any]: # Both new BaseParams and legacy test_new.framework.BaseParams expose this kw = self.optimi_params.to_optimi_kwargs(self.optimi_class) # Centralize backend controls so runners don't mutate kwargs later if backend is not None: - if backend == "triton": + if backend == 
Backend.triton: kw["triton"] = True kw["foreach"] = False - elif backend == "torch": + elif backend == Backend.torch: kw["triton"] = False kw["foreach"] = False - elif backend == "foreach": + elif backend == Backend.foreach: kw["triton"] = False kw["foreach"] = True else: raise ValueError(f"Unknown backend: {backend}") return kw - def to_reference_kwargs(self) -> dict[str, Any]: + def to_reference_kwargs(self, backend: Backend | None = None) -> dict[str, Any]: assert self.reference_params is not None - return self.reference_params.to_reference_kwargs(self.reference_class) + kwargs = self.reference_params.to_reference_kwargs(self.reference_class) + # Centralize fused handling for reference optimizers: when not testing + # Optimi's Triton backend, avoid fused codepaths on the reference side + # to mirror legacy parity expectations. + if backend is not None and backend != Backend.triton: + try: + if "fused" in inspect.signature(self.reference_class.__init__).parameters: + kwargs = {**kwargs, "fused": False} + except (ValueError, TypeError): + pass + return kwargs def supports_l2_weight_decay(self) -> bool: - import inspect - return "decouple_wd" in inspect.signature(self.optimi_class.__init__).parameters @@ -135,9 +157,7 @@ def default_variants(base: Case) -> list[Case]: ) out.append(base0) - import inspect - - # L2 (coupled) if optimizer supports decouple_wd arg + # L2 weight decay if optimizer supports decouple_wd arg if "decouple_wd" in inspect.signature(base.optimi_class.__init__).parameters: out.append( replace( diff --git a/test_new/config.py b/test_new/config.py new file mode 100644 index 0000000..ee0d2e3 --- /dev/null +++ b/test_new/config.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +import torch + +from .cases import Tolerance + + +@dataclass(frozen=True) +class CorrectnessDefaults: + cpu_iterations: int = 20 + gpu_iterations: int = 40 + # Special-case: Adan in bf16 on GPU is noisier; align to 20 + adan_bf16_gpu_iterations: int = 20 + + cpu_dims: tuple[int, int] = (64, 128) + gpu_dims: tuple[int, int] = (256, 512) + + cpu_batch_size: int = 1 + gpu_batch_size: int = 32 + + cpu_max_error_count: int = 2 + gpu_max_error_count: int = 5 + + +@dataclass(frozen=True) +class GradientReleaseDefaults: + iterations: int = 40 + dims: tuple[int, int] = (128, 256) + batch_size: int = 32 + max_error_count: int = 12 # more lenient for noisy updates + + baseline_tolerance: dict[torch.dtype, Tolerance] = field( + default_factory=lambda: { + torch.float32: Tolerance(atol=1e-6, rtol=1e-5, max_error_rate=5e-4), + torch.bfloat16: Tolerance(atol=1e-3, rtol=1e-2, max_error_rate=0.01), + torch.float16: Tolerance(atol=1e-4, rtol=1e-3, max_error_rate=0.01), + } + ) + + +@dataclass(frozen=True) +class AccumulationDefaults: + iterations: int = 40 + dims: tuple[int, int] = (128, 256) + batch_size: int = 32 + tolerance: Tolerance = field(default_factory=lambda: Tolerance(rtol=1e-2, atol=1e-2)) + max_error_rate: float = 0.035 + gradient_accumulation_steps: int = 4 + + +@dataclass(frozen=True) +class TestDefaults: + correctness: CorrectnessDefaults = CorrectnessDefaults() + gradient_release: GradientReleaseDefaults = GradientReleaseDefaults() + accumulation: AccumulationDefaults = AccumulationDefaults() + + +# Single place to tweak numbers used by runners +DEFAULTS = TestDefaults() diff --git a/test_new/opt_adam.py b/test_new/opt_adam.py index 4260617..c9164fc 100644 --- a/test_new/opt_adam.py +++ b/test_new/opt_adam.py @@ -15,7 +15,7 @@ class 
AdamParams(BaseParams): # Provide BASE so the framework generates base/l2/decoupled variants as applicable. -# For Adam, we disable decoupled WD/LR generation to match prior behavior. +# Disable decoupled WD/LR generation as this is tested in AdamW tests. BASE = Case( name="adam", optimi_class=optimi.Adam, diff --git a/test_new/opt_ranger.py b/test_new/opt_ranger.py index 4c48d3c..b6e440c 100644 --- a/test_new/opt_ranger.py +++ b/test_new/opt_ranger.py @@ -5,7 +5,7 @@ import optimi from tests import reference -from .cases import BaseParams, Case +from .cases import BaseParams, Case, TestType @dataclass @@ -23,5 +23,7 @@ class RangerParams(BaseParams): optimi_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), reference_class=reference.Ranger, reference_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), + # Match legacy longer gradient-release coverage due to Lookahead cadence. + custom_iterations={TestType.gradient_release: 160}, ) ] diff --git a/test_new/opt_sgd.py b/test_new/opt_sgd.py index 518e8fb..c473674 100644 --- a/test_new/opt_sgd.py +++ b/test_new/opt_sgd.py @@ -7,7 +7,7 @@ import torch from tests import reference -from .cases import BaseParams, Case +from .cases import BaseParams, Case, TestType @dataclass @@ -32,7 +32,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=SGDParams(lr=1e-3, momentum=0, dampening=False, weight_decay=0), reference_class=torch.optim.SGD, reference_params=SGDParams(lr=1e-3, momentum=0, dampening=0, weight_decay=0), - skip_tests=["accumulation"], + skip_tests=[TestType.accumulation], ), Case( name="sgd_momentum", @@ -54,7 +54,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=1e-2, decouple_wd=False), reference_class=torch.optim.SGD, reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=1e-2), - skip_tests=["accumulation"], + skip_tests=[TestType.accumulation], ), Case( name="sgd_decoupled_lr", diff --git a/test_new/opt_stableadamw.py b/test_new/opt_stableadamw.py index 87db705..82af545 100644 --- a/test_new/opt_stableadamw.py +++ b/test_new/opt_stableadamw.py @@ -10,7 +10,7 @@ @dataclass class StableAdamWParams(BaseParams): - betas: tuple[float, float] = (0.9, 0.999) + betas: tuple[float, float] = (0.9, 0.99) eps: float = 1e-8 diff --git a/test_new/runners.py b/test_new/runners.py index 7862fb0..ed45b18 100644 --- a/test_new/runners.py +++ b/test_new/runners.py @@ -1,13 +1,32 @@ from __future__ import annotations import io -from typing import Optional - +import random import torch from optimi import prepare_for_gradient_release, remove_gradient_release from torch import Tensor -from .cases import Case, Tolerance +from .cases import Backend, Case, DeviceType, TestType, Tolerance +from .config import DEFAULTS + + +def _device_type(device: torch.device) -> DeviceType: + return DeviceType.cpu if device.type == "cpu" else DeviceType.gpu + + +def _get_iterations( + case: Case, + test_type: TestType, + default: int, + device: torch.device | None = None, +) -> int: + if not case.custom_iterations: + return default + if device is not None: + key = (test_type, _device_type(device)) + if key in case.custom_iterations: + return case.custom_iterations[key] + return case.custom_iterations.get(test_type, default) def assert_most_approx_close( @@ -55,31 +74,25 @@ def run_correctness( case: Case, device: torch.device, dtype: torch.dtype, - backend: str, + 
backend: Backend, dims: tuple[int, int] | None = None, ) -> None: - # iterations: default parity CPU=20, GPU=40 unless overridden - iterations = case.custom_iterations.get("correctness") if case.custom_iterations else None - if iterations is None: - iterations = 20 if device.type == "cpu" else 40 - # Adan bfloat16 updates are noisier on GPU: align with tests which use 20 iters - if device.type != "cpu" and dtype == torch.bfloat16 and case.optimizer_name == "adan": - iterations = 20 + # Iterations and tolerance + default_iters = DEFAULTS.correctness.cpu_iterations if device.type == "cpu" else DEFAULTS.correctness.gpu_iterations + iterations = _get_iterations(case, TestType.correctness, default_iters, device=device) + # Special-case: Adan bf16 on GPU + if device.type != "cpu" and dtype == torch.bfloat16 and case.optimizer_name == "adan": + iterations = DEFAULTS.correctness.adan_bf16_gpu_iterations tolerance = case.custom_tolerances[dtype] - - # Dimensions and error counts + # Dims, batch, errors if dims is not None: dim1, dim2 = dims elif device.type == "cpu": - dim1, dim2 = 64, 128 + dim1, dim2 = DEFAULTS.correctness.cpu_dims else: - dim1, dim2 = 256, 512 - - batch_size = 1 if device.type == "cpu" else 32 - max_error_count = 2 if device.type == "cpu" else 5 - - # bfloat16 error rate (kept for parity; not used directly here) - _max_error_rate: Optional[float] = 0.01 if dtype == torch.bfloat16 else None + dim1, dim2 = DEFAULTS.correctness.gpu_dims + batch_size = DEFAULTS.correctness.cpu_batch_size if device.type == "cpu" else DEFAULTS.correctness.gpu_batch_size + max_error_count = DEFAULTS.correctness.cpu_max_error_count if device.type == "cpu" else DEFAULTS.correctness.gpu_max_error_count # Create models m1 = MLP(dim1, dim2, device=device, dtype=dtype) @@ -93,7 +106,7 @@ def run_correctness( # Optimizers reference_class = case.reference_class - reference_kwargs = case.to_reference_kwargs() + reference_kwargs = case.to_reference_kwargs(backend) optimi_kwargs = case.to_optimi_kwargs(backend) reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) optimi_optimizer = case.optimi_class(m2.parameters(), **optimi_kwargs) @@ -175,39 +188,35 @@ def run_gradient_release( case: Case, device: torch.device, dtype: torch.dtype, - backend: str, + backend: Backend, dims: tuple[int, int] | None = None, ) -> None: def optimizer_hook(parameter) -> None: torch_optimizers[parameter].step() torch_optimizers[parameter].zero_grad() - # iterations: default parity GPU=40 unless overridden - iterations = case.custom_iterations.get("gradient_release") if case.custom_iterations else None - if iterations is None: - iterations = 40 - tolerance = case.custom_tolerances[dtype] - # Enforce a minimal baseline tolerance for gradient-release parity (from parity notes) - if dtype == torch.float32: - baseline = Tolerance(atol=1e-6, rtol=1e-5, max_error_rate=5e-4) - elif dtype == torch.bfloat16: - baseline = Tolerance(atol=1e-3, rtol=1e-2, max_error_rate=0.01) - elif dtype == torch.float16: - baseline = Tolerance(atol=1e-4, rtol=1e-3, max_error_rate=0.01) - else: - baseline = tolerance + # Iterations + iterations = _get_iterations(case, TestType.gradient_release, DEFAULTS.gradient_release.iterations, device=device) + + # Tolerances: merge baseline with per-case + tol = case.custom_tolerances[dtype] + baseline = DEFAULTS.gradient_release.baseline_tolerance.get(dtype, tol) tolerance = Tolerance( - rtol=max(tolerance.rtol, baseline.rtol), - atol=max(tolerance.atol, baseline.atol), - 
max_error_rate=max(tolerance.max_error_rate, baseline.max_error_rate), - equal_nan=tolerance.equal_nan, + rtol=max(tol.rtol, baseline.rtol), + atol=max(tol.atol, baseline.atol), + max_error_rate=max(tol.max_error_rate, baseline.max_error_rate), + equal_nan=tol.equal_nan, ) - max_error_count = 12 # more lenient for noisy updates + max_error_count = DEFAULTS.gradient_release.max_error_count + + # Dims and batch size + if dims is not None: + dim1, dim2 = dims + else: + dim1, dim2 = DEFAULTS.gradient_release.dims - # Dims: default 128x256 unless provided (tests also use 128x1024) - dim1, dim2 = dims if dims is not None else (128, 256) - batch_size = 32 + batch_size = DEFAULTS.gradient_release.batch_size m1 = MLP(dim1, dim2, device=device, dtype=dtype) # regular m2 = MLP(dim1, dim2, device=device, dtype=dtype) # PyTorch hooks @@ -216,7 +225,7 @@ def optimizer_hook(parameter) -> None: m3.load_state_dict(m1.state_dict()) reference_class = case.reference_class - reference_kwargs = case.to_reference_kwargs() + reference_kwargs = case.to_reference_kwargs(backend) optimi_kwargs = case.to_optimi_kwargs(backend) regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) @@ -252,8 +261,10 @@ def optimizer_hook(parameter) -> None: regular_optimizer.step() regular_optimizer.zero_grad() - # Optional framework step for determinism left disabled - # optimi_optimizer.step(); optimi_optimizer.zero_grad() + # Random step/zero_grad to simulate using optimi's accumulation in a framework like Composer + if random.random() < 0.5: + optimi_optimizer.step() + optimi_optimizer.zero_grad() assert_most_approx_close( m1.fc1.weight, @@ -301,27 +312,30 @@ def run_accumulation( case: Case, device: torch.device, dtype: torch.dtype, - backend: str, + backend: Backend, dims: tuple[int, int] | None = None, ) -> None: - # iterations: default parity GPU=40 unless overridden - iterations = case.custom_iterations.get("accumulation") if case.custom_iterations else None - if iterations is None: - iterations = 40 + # Iterations + iterations = _get_iterations(case, TestType.accumulation, DEFAULTS.accumulation.iterations, device=device) + + # Dims and batch size + if dims is not None: + dim1, dim2 = dims + else: + dim1, dim2 = DEFAULTS.accumulation.dims - max_error_rate = 0.035 - tolerance = Tolerance(rtol=1e-2, atol=1e-2) + batch_size = DEFAULTS.accumulation.batch_size - # Dims: default 128x256 unless provided (tests also use 128x1024) - dim1, dim2 = dims if dims is not None else (128, 256) - batch_size = 32 + # Tolerance and error rate + tolerance = DEFAULTS.accumulation.tolerance + max_error_rate = DEFAULTS.accumulation.max_error_rate m1 = MLP(dim1, dim2, device=device, dtype=dtype) # Regular optimizer m2 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi accumulation m2.load_state_dict(m1.state_dict()) reference_class = case.reference_class - reference_kwargs = case.to_reference_kwargs() + reference_kwargs = case.to_reference_kwargs(backend) optimi_kwargs = case.to_optimi_kwargs(backend) regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) @@ -329,7 +343,7 @@ def run_accumulation( optimi_optimizer = case.optimi_class(m2.parameters(), **optimi_kwargs) prepare_for_gradient_release(m2, optimi_optimizer) - gradient_accumulation_steps = 4 + gradient_accumulation_steps = DEFAULTS.accumulation.gradient_accumulation_steps for i in range(iterations): input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) @@ -351,8 +365,10 @@ def run_accumulation( regular_optimizer.step() 
regular_optimizer.zero_grad() - # Optional framework step left disabled to mirror prior behavior - # optimi_optimizer.step(); optimi_optimizer.zero_grad() + # Random step/zero_grad to simulate using optimi's accumulation in a framework like Composer + if random.random() < 0.5: + optimi_optimizer.step() + optimi_optimizer.zero_grad() assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py index 9aa8b06..dc2ee25 100644 --- a/test_new/test_optimizers_unified.py +++ b/test_new/test_optimizers_unified.py @@ -2,22 +2,22 @@ import torch from optimi.utils import MIN_TORCH_2_6 -from .cases import discover_cases +from .cases import Backend, DeviceType, TestType, discover_cases from .runners import run_accumulation, run_correctness, run_gradient_release CASES = tuple(discover_cases()) DEVICE_PARAMS = [ - pytest.param("cpu", marks=pytest.mark.cpu, id="cpu"), - pytest.param("gpu", marks=pytest.mark.gpu, id="gpu"), + pytest.param(DeviceType.cpu, marks=pytest.mark.cpu, id=DeviceType.cpu.value), + pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value), ] DTYPE_PARAMS = [ pytest.param(torch.float32, marks=pytest.mark.float32, id="float32"), pytest.param(torch.bfloat16, marks=pytest.mark.bfloat16, id="bfloat16"), ] BACKEND_PARAMS = [ - pytest.param("torch", marks=pytest.mark.torch, id="torch"), - pytest.param("triton", marks=pytest.mark.triton, id="triton"), + pytest.param(Backend.torch, marks=pytest.mark.torch, id=Backend.torch.value), + pytest.param(Backend.triton, marks=pytest.mark.triton, id=Backend.triton.value), ] # Attach per-optimizer marks so users can -m adam, -m sgd, etc. 
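# --- Illustrative aside, not part of the patch ---
# A minimal sketch of how the marks above can be combined to run one slice of
# the test matrix. The mark names (adam, gpu, float32, triton) are registered
# in conftest.py and attached per case/parameter in this module; the "test_new"
# target path and this standalone entry point are assumptions for illustration.
import pytest

if __name__ == "__main__":
    # Select only Adam cases on GPU, in float32, using the Triton backend.
    pytest.main(["test_new", "-m", "adam and gpu and float32 and triton", "-q"])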
@@ -26,12 +26,12 @@ # Dimension parameter spaces (match legacy tests) # Correctness dims: CPU -> (64,64), (64,128); GPU -> (256,256), (256,512), (256,1024), (256,2048) CORRECTNESS_DIMS = [ - pytest.param(("cpu", (64, 64)), id="cpu-64x64"), - pytest.param(("cpu", (64, 128)), id="cpu-64x128"), - pytest.param(("gpu", (256, 256)), id="gpu-256x256"), - pytest.param(("gpu", (256, 512)), id="gpu-256x512"), - pytest.param(("gpu", (256, 1024)), id="gpu-256x1024"), - pytest.param(("gpu", (256, 2048)), id="gpu-256x2048"), + pytest.param((DeviceType.cpu, (64, 64)), id="cpu-64x64"), + pytest.param((DeviceType.cpu, (64, 128)), id="cpu-64x128"), + pytest.param((DeviceType.gpu, (256, 256)), id="gpu-256x256"), + pytest.param((DeviceType.gpu, (256, 512)), id="gpu-256x512"), + pytest.param((DeviceType.gpu, (256, 1024)), id="gpu-256x1024"), + pytest.param((DeviceType.gpu, (256, 2048)), id="gpu-256x2048"), ] # Gradient release and accumulation dims (GPU-only): (128,256) and (128,1024) @@ -41,7 +41,7 @@ ] -def _should_skip(test_type, case, device_type, dtype, backend) -> bool: +def _should_skip(test_type: TestType, case, device_type: DeviceType, dtype, backend: Backend) -> bool: # Explicit per-case skip if test_type in set(case.skip_tests): return True @@ -51,15 +51,21 @@ def _should_skip(test_type, case, device_type, dtype, backend) -> bool: return True # Skip triton on CPU - if backend == "triton" and device_type == "cpu": + if backend == Backend.triton and device_type == DeviceType.cpu: return True # Triton requires torch >= 2.6 - if backend == "triton" and not MIN_TORCH_2_6: + if backend == Backend.triton and not MIN_TORCH_2_6: + return True + + # Triton is not supported on MPS + if backend == Backend.triton and not ( + torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()) + ) and (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()): return True # GPU availability - if device_type == "gpu" and not ( + if device_type == DeviceType.gpu and not ( torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()) or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) @@ -67,11 +73,20 @@ def _should_skip(test_type, case, device_type, dtype, backend) -> bool: return True # Gradient release and accumulation are GPU-only tests - if test_type in ["gradient_release", "accumulation"] and device_type == "cpu": + if test_type in {TestType.gradient_release, TestType.accumulation} and device_type == DeviceType.cpu: + return True + + # bfloat16 is not supported on MPS + if ( + device_type == DeviceType.gpu + and dtype == torch.bfloat16 + and not (torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())) + and (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) + ): return True # Skip bfloat16 on CPU for most optimizers; allow anyadam-style exceptions via case.any_precision if needed - if device_type == "cpu" and dtype == torch.bfloat16 and not case.any_precision: + if device_type == DeviceType.cpu and dtype == torch.bfloat16 and not case.any_precision: return True return False @@ -83,32 +98,32 @@ def _should_skip(test_type, case, device_type, dtype, backend) -> bool: @pytest.mark.parametrize("backend", BACKEND_PARAMS) @pytest.mark.parametrize("dims_spec", CORRECTNESS_DIMS) def test_correctness(case, device_type, dtype, backend, dims_spec, gpu_device): - if _should_skip("correctness", case, device_type, dtype, backend): + if _should_skip(TestType.correctness, case, device_type, dtype, backend): 
pytest.skip() dims_device, dims = dims_spec if dims_device != device_type: pytest.skip() - device = torch.device(gpu_device if device_type == "gpu" else "cpu") + device = torch.device(gpu_device if device_type == DeviceType.gpu else "cpu") run_correctness(case, device, dtype, backend, dims=dims) @pytest.mark.parametrize("case", OPTIM_PARAMS) -@pytest.mark.parametrize("device_type", [pytest.param("gpu", marks=pytest.mark.gpu, id="gpu")]) +@pytest.mark.parametrize("device_type", [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)]) @pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) @pytest.mark.parametrize("backend", BACKEND_PARAMS) @pytest.mark.parametrize("dims", GR_DIMS) def test_gradient_release(case, device_type, dtype, backend, dims, gpu_device): - if _should_skip("gradient_release", case, device_type, dtype, backend): + if _should_skip(TestType.gradient_release, case, device_type, dtype, backend): pytest.skip() run_gradient_release(case, torch.device(gpu_device), dtype, backend, dims=dims) @pytest.mark.parametrize("case", OPTIM_PARAMS) -@pytest.mark.parametrize("device_type", [pytest.param("gpu", marks=pytest.mark.gpu, id="gpu")]) +@pytest.mark.parametrize("device_type", [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)]) @pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) @pytest.mark.parametrize("backend", BACKEND_PARAMS) @pytest.mark.parametrize("dims", GR_DIMS) def test_accumulation(case, device_type, dtype, backend, dims, gpu_device): - if _should_skip("accumulation", case, device_type, dtype, backend): + if _should_skip(TestType.accumulation, case, device_type, dtype, backend): pytest.skip() run_accumulation(case, torch.device(gpu_device), dtype, backend, dims=dims) From d6a90dcdd4d4e18ce67bb59c0c3e2f8f7926dc3a Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 23 Dec 2025 18:36:25 -0600 Subject: [PATCH 06/16] small improvements --- test_new/cases.py | 32 +++++++++++++----------- test_new/opt_adam.py | 6 ++--- test_new/opt_adamw.py | 6 ++--- test_new/opt_adan.py | 12 ++++----- test_new/opt_anyadam.py | 10 ++++---- test_new/opt_lion.py | 10 ++++---- test_new/opt_radam.py | 6 ++--- test_new/opt_ranger.py | 6 ++--- test_new/opt_sgd.py | 14 +++++------ test_new/opt_stableadamw.py | 6 ++--- test_new/runners.py | 11 +++++---- test_new/test_optimizers_unified.py | 38 ++++++++++++++--------------- 12 files changed, 81 insertions(+), 76 deletions(-) diff --git a/test_new/cases.py b/test_new/cases.py index c800c87..ef700b7 100644 --- a/test_new/cases.py +++ b/test_new/cases.py @@ -1,5 +1,6 @@ import importlib import inspect +import warnings from dataclasses import asdict, dataclass, field, replace from enum import Enum from pathlib import Path @@ -51,7 +52,10 @@ def _kwargs_for(self, cls: type | None) -> dict[str, Any]: return {} sig = inspect.signature(cls.__init__) ok = set(sig.parameters) - {"self"} - return {k: v for k, v in asdict(self).items() if k in ok} + values = asdict(self) + if values.get("triton") and "triton" not in ok: + warnings.warn(f"{cls.__name__} does not accept triton; ignoring BaseParams.triton=True.", RuntimeWarning) + return {k: v for k, v in values.items() if k in ok} def to_optimi_kwargs(self, cls: type[OptimiOptimizer]) -> dict[str, Any]: return self._kwargs_for(cls) @@ -61,7 +65,7 @@ def to_reference_kwargs(self, cls: type[Optimizer]) -> dict[str, Any]: @dataclass -class Case: +class OptTest: # 
Identification name: str # e.g. "adam_base" @@ -137,11 +141,11 @@ def supports_l2_weight_decay(self) -> bool: return "decouple_wd" in inspect.signature(self.optimi_class.__init__).parameters -def default_variants(base: Case) -> list[Case]: +def default_variants(base: OptTest) -> list[OptTest]: """Generate base + L2 + decoupled variants with minimal boilerplate.""" - out: list[Case] = [] + out: list[OptTest] = [] - base0 = Case( + base_test = OptTest( name=f"{base.optimizer_name}_base", optimi_class=base.optimi_class, optimi_params=base.optimi_params.with_(weight_decay=0.0, decouple_wd=False, decouple_lr=False), @@ -155,13 +159,13 @@ def default_variants(base: Case) -> list[Case]: only_dtypes=base.only_dtypes, fully_decoupled_reference=base.fully_decoupled_reference, ) - out.append(base0) + out.append(base_test) # L2 weight decay if optimizer supports decouple_wd arg if "decouple_wd" in inspect.signature(base.optimi_class.__init__).parameters: out.append( replace( - base0, + base_test, name=f"{base.optimizer_name}_l2_wd", optimi_params=base.optimi_params.with_(weight_decay=0.01, decouple_wd=False), reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.01, decouple_wd=False), @@ -172,7 +176,7 @@ def default_variants(base: Case) -> list[Case]: if base.test_decoupled_wd: out.append( replace( - base0, + base_test, name=f"{base.optimizer_name}_decoupled_wd", optimi_params=base.optimi_params.with_(weight_decay=0.01, decouple_wd=True), reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.01, decouple_wd=True), @@ -183,7 +187,7 @@ def default_variants(base: Case) -> list[Case]: ref_cls = base.fully_decoupled_reference or base.reference_class out.append( replace( - base0, + base_test, name=f"{base.optimizer_name}_decoupled_lr", optimi_params=base.optimi_params.with_(weight_decay=1e-5, decouple_lr=True), reference_class=ref_cls, @@ -196,15 +200,15 @@ def default_variants(base: Case) -> list[Case]: return out -def discover_cases(root: Path | None = None) -> list[Case]: +def discover_tests(root: Path | None = None) -> list[OptTest]: """ Discover `opt_*.py` modules in this package. Accept exactly: - - TESTS: list[Case] - - BASE: Case -> expanded via default_variants(BASE) + - TESTS: list[OptTest] + - BASE: OptTest -> expanded via default_variants(BASE) """ if root is None: root = Path(__file__).parent - cases: list[Case] = [] + cases: list[OptTest] = [] for f in root.glob("opt_*.py"): mod = importlib.import_module(f".{f.stem}", package=__package__) if hasattr(mod, "TESTS"): @@ -216,4 +220,4 @@ def discover_cases(root: Path | None = None) -> list[Case]: def optimizer_names() -> list[str]: - return sorted({c.optimizer_name for c in discover_cases()}) + return sorted({c.optimizer_name for c in discover_tests()}) diff --git a/test_new/opt_adam.py b/test_new/opt_adam.py index c9164fc..99afa15 100644 --- a/test_new/opt_adam.py +++ b/test_new/opt_adam.py @@ -1,11 +1,11 @@ -"""Adam optimizer definitions using the new Case/variants flow.""" +"""Adam optimizer definitions using the new OptTest/variants flow.""" from dataclasses import dataclass import optimi import torch -from .cases import BaseParams, Case +from .cases import BaseParams, OptTest @dataclass @@ -16,7 +16,7 @@ class AdamParams(BaseParams): # Provide BASE so the framework generates base/l2/decoupled variants as applicable. # Disable decoupled WD/LR generation as this is tested in AdamW tests. 
-BASE = Case( +BASE = OptTest( name="adam", optimi_class=optimi.Adam, optimi_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.0), diff --git a/test_new/opt_adamw.py b/test_new/opt_adamw.py index c4d9525..f0cbc5a 100644 --- a/test_new/opt_adamw.py +++ b/test_new/opt_adamw.py @@ -1,4 +1,4 @@ -"""AdamW optimizer definitions using the new Case/variants flow.""" +"""AdamW optimizer definitions using the new OptTest/variants flow.""" from dataclasses import dataclass @@ -6,7 +6,7 @@ import torch from tests import reference -from .cases import BaseParams, Case +from .cases import BaseParams, OptTest @dataclass @@ -16,7 +16,7 @@ class AdamWParams(BaseParams): # Provide BASE with fully_decoupled_reference so decoupled_lr uses DecoupledAdamW -BASE = Case( +BASE = OptTest( name="adamw", optimi_class=optimi.AdamW, optimi_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), diff --git a/test_new/opt_adan.py b/test_new/opt_adan.py index 302e9e2..48171a5 100644 --- a/test_new/opt_adan.py +++ b/test_new/opt_adan.py @@ -1,4 +1,4 @@ -"""Adan optimizer tests using the new Case format (manual list).""" +"""Adan optimizer tests using the new OptTest format (manual list).""" from dataclasses import dataclass from typing import Any @@ -6,7 +6,7 @@ import optimi from tests import reference -from .cases import BaseParams, Case +from .cases import BaseParams, OptTest @dataclass @@ -25,28 +25,28 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: TESTS = [ - Case( + OptTest( name="adan_base", optimi_class=optimi.Adan, optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=0), reference_class=reference.Adan, reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6), ), - Case( + OptTest( name="adan_weight_decay", optimi_class=optimi.Adan, optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), reference_class=reference.Adan, reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), ), - Case( + OptTest( name="adan_adam_wd", optimi_class=optimi.Adan, optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, adam_wd=True), reference_class=reference.Adan, reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, weight_decouple=True), ), - Case( + OptTest( name="adan_decoupled_lr", optimi_class=optimi.Adan, optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-5, decouple_lr=True), diff --git a/test_new/opt_anyadam.py b/test_new/opt_anyadam.py index df0bb6f..8301f03 100644 --- a/test_new/opt_anyadam.py +++ b/test_new/opt_anyadam.py @@ -1,4 +1,4 @@ -"""AnyAdam optimizer tests using the new Case format (manual list).""" +"""AnyAdam optimizer tests using the new OptTest format (manual list).""" from dataclasses import dataclass @@ -6,7 +6,7 @@ import torch from tests.reference import AnyPrecisionAdamW -from .cases import BaseParams, Case, Tolerance +from .cases import BaseParams, OptTest, Tolerance @dataclass @@ -28,7 +28,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict: TESTS = [ - Case( + OptTest( name="anyadam_kahan", optimi_class=optimi.Adam, optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, kahan_sum=True), @@ -38,7 +38,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict: any_precision=True, custom_tolerances={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01, 
equal_nan=False)}, ), - Case( + OptTest( name="anyadam_kahan_wd", optimi_class=optimi.AdamW, optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.01, kahan_sum=True), @@ -48,7 +48,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict: any_precision=True, custom_tolerances={torch.bfloat16: Tolerance(rtol=5e-2, atol=1e-2, max_error_rate=0.01, equal_nan=False)}, ), - Case( + OptTest( name="anyadam_kahan_decoupled_lr", optimi_class=optimi.AdamW, optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-5, decouple_lr=True, kahan_sum=True), diff --git a/test_new/opt_lion.py b/test_new/opt_lion.py index f9092c4..63c1f84 100644 --- a/test_new/opt_lion.py +++ b/test_new/opt_lion.py @@ -1,11 +1,11 @@ -"""Lion optimizer tests in new Case format (manual list to match prior values).""" +"""Lion optimizer tests in new OptTest format (manual list to match prior values).""" from dataclasses import dataclass import optimi from tests.reference import lion as reference_lion -from .cases import BaseParams, Case +from .cases import BaseParams, OptTest @dataclass @@ -14,21 +14,21 @@ class LionParams(BaseParams): TESTS = [ - Case( + OptTest( name="lion_base", optimi_class=optimi.Lion, optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), reference_class=reference_lion.Lion, reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), ), - Case( + OptTest( name="lion_decoupled_wd", optimi_class=optimi.Lion, optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1, decouple_wd=True), reference_class=reference_lion.Lion, reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1), ), - Case( + OptTest( name="lion_decoupled_lr", optimi_class=optimi.Lion, optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-5, decouple_lr=True), diff --git a/test_new/opt_radam.py b/test_new/opt_radam.py index 07d44d7..ed84ff3 100644 --- a/test_new/opt_radam.py +++ b/test_new/opt_radam.py @@ -1,4 +1,4 @@ -"""RAdam optimizer definitions using the new Case/variants flow.""" +"""RAdam optimizer definitions using the new OptTest/variants flow.""" import inspect from dataclasses import dataclass, field @@ -6,7 +6,7 @@ import optimi import torch -from .cases import BaseParams, Case, Tolerance +from .cases import BaseParams, OptTest, Tolerance @dataclass @@ -20,7 +20,7 @@ def __post_init__(self): self.decoupled_weight_decay = True -BASE = Case( +BASE = OptTest( name="radam", optimi_class=optimi.RAdam, optimi_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), diff --git a/test_new/opt_ranger.py b/test_new/opt_ranger.py index b6e440c..b831704 100644 --- a/test_new/opt_ranger.py +++ b/test_new/opt_ranger.py @@ -1,11 +1,11 @@ -"""Ranger optimizer tests using new Case format (base only).""" +"""Ranger optimizer tests using new OptTest format (base only).""" from dataclasses import dataclass import optimi from tests import reference -from .cases import BaseParams, Case, TestType +from .cases import BaseParams, OptTest, TestType @dataclass @@ -17,7 +17,7 @@ class RangerParams(BaseParams): TESTS = [ - Case( + OptTest( name="ranger_base", optimi_class=optimi.Ranger, optimi_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), diff --git a/test_new/opt_sgd.py b/test_new/opt_sgd.py index c473674..642b31d 100644 --- a/test_new/opt_sgd.py +++ b/test_new/opt_sgd.py @@ -1,4 +1,4 @@ -"""SGD optimizer definitions using the new Case/variants flow (manual list).""" +"""SGD optimizer 
definitions using the new OptTest/variants flow (manual list).""" from dataclasses import dataclass from typing import Any @@ -7,7 +7,7 @@ import torch from tests import reference -from .cases import BaseParams, Case, TestType +from .cases import BaseParams, OptTest, TestType @dataclass @@ -26,7 +26,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: # Manual list to mirror the original explicit coverage for SGD TESTS = [ - Case( + OptTest( name="sgd_base", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0, dampening=False, weight_decay=0), @@ -34,21 +34,21 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: reference_params=SGDParams(lr=1e-3, momentum=0, dampening=0, weight_decay=0), skip_tests=[TestType.accumulation], ), - Case( + OptTest( name="sgd_momentum", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=0), reference_class=torch.optim.SGD, reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=0), ), - Case( + OptTest( name="sgd_dampening", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=True, weight_decay=0, torch_init=True), reference_class=torch.optim.SGD, reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=0), ), - Case( + OptTest( name="sgd_weight_decay", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=1e-2, decouple_wd=False), @@ -56,7 +56,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=1e-2), skip_tests=[TestType.accumulation], ), - Case( + OptTest( name="sgd_decoupled_lr", optimi_class=optimi.SGD, optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=True, decouple_lr=True, weight_decay=1e-5, torch_init=True), diff --git a/test_new/opt_stableadamw.py b/test_new/opt_stableadamw.py index 82af545..709ea2f 100644 --- a/test_new/opt_stableadamw.py +++ b/test_new/opt_stableadamw.py @@ -1,11 +1,11 @@ -"""StableAdamW optimizer definitions using new Case/variants flow.""" +"""StableAdamW optimizer definitions using new OptTest/variants flow.""" from dataclasses import dataclass import optimi from tests import reference -from .cases import BaseParams, Case +from .cases import BaseParams, OptTest @dataclass @@ -14,7 +14,7 @@ class StableAdamWParams(BaseParams): eps: float = 1e-8 -BASE = Case( +BASE = OptTest( name="stableadamw", optimi_class=optimi.StableAdamW, optimi_params=StableAdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), diff --git a/test_new/runners.py b/test_new/runners.py index ed45b18..a5b1f33 100644 --- a/test_new/runners.py +++ b/test_new/runners.py @@ -2,11 +2,12 @@ import io import random + import torch from optimi import prepare_for_gradient_release, remove_gradient_release from torch import Tensor -from .cases import Backend, Case, DeviceType, TestType, Tolerance +from .cases import Backend, DeviceType, OptTest, TestType, Tolerance from .config import DEFAULTS @@ -15,7 +16,7 @@ def _device_type(device: torch.device) -> DeviceType: def _get_iterations( - case: Case, + case: OptTest, test_type: TestType, default: int, device: torch.device | None = None, @@ -71,7 +72,7 @@ def forward(self, x: Tensor) -> Tensor: def run_correctness( - case: Case, + case: OptTest, device: torch.device, dtype: torch.dtype, backend: Backend, @@ -185,7 +186,7 @@ def run_correctness( def run_gradient_release( - case: 
Case, + case: OptTest, device: torch.device, dtype: torch.dtype, backend: Backend, @@ -309,7 +310,7 @@ def optimizer_hook(parameter) -> None: def run_accumulation( - case: Case, + case: OptTest, device: torch.device, dtype: torch.dtype, backend: Backend, diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py index dc2ee25..de4d848 100644 --- a/test_new/test_optimizers_unified.py +++ b/test_new/test_optimizers_unified.py @@ -2,11 +2,9 @@ import torch from optimi.utils import MIN_TORCH_2_6 -from .cases import Backend, DeviceType, TestType, discover_cases +from .cases import Backend, DeviceType, OptTest, TestType, discover_tests from .runners import run_accumulation, run_correctness, run_gradient_release -CASES = tuple(discover_cases()) - DEVICE_PARAMS = [ pytest.param(DeviceType.cpu, marks=pytest.mark.cpu, id=DeviceType.cpu.value), pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value), @@ -21,11 +19,11 @@ ] # Attach per-optimizer marks so users can -m adam, -m sgd, etc. -OPTIM_PARAMS = [pytest.param(c, id=c.name, marks=getattr(pytest.mark, c.optimizer_name)) for c in CASES] +OPTIMIZERS = [pytest.param(c, id=c.name, marks=getattr(pytest.mark, c.optimizer_name)) for c in discover_tests()] # Dimension parameter spaces (match legacy tests) # Correctness dims: CPU -> (64,64), (64,128); GPU -> (256,256), (256,512), (256,1024), (256,2048) -CORRECTNESS_DIMS = [ +FULL_DIMS = [ pytest.param((DeviceType.cpu, (64, 64)), id="cpu-64x64"), pytest.param((DeviceType.cpu, (64, 128)), id="cpu-64x128"), pytest.param((DeviceType.gpu, (256, 256)), id="gpu-256x256"), @@ -34,14 +32,14 @@ pytest.param((DeviceType.gpu, (256, 2048)), id="gpu-256x2048"), ] -# Gradient release and accumulation dims (GPU-only): (128,256) and (128,1024) -GR_DIMS = [ +# Gradient release and accumulation dims: (128,256) and (128,1024) +SUBSET_DIMS = [ pytest.param((128, 256), id="gr-128x256"), pytest.param((128, 1024), id="gr-128x1024"), ] -def _should_skip(test_type: TestType, case, device_type: DeviceType, dtype, backend: Backend) -> bool: +def _should_skip(test_type: TestType, case: OptTest, device_type: DeviceType, dtype: torch.dtype, backend: Backend) -> bool: # Explicit per-case skip if test_type in set(case.skip_tests): return True @@ -59,9 +57,11 @@ def _should_skip(test_type: TestType, case, device_type: DeviceType, dtype, back return True # Triton is not supported on MPS - if backend == Backend.triton and not ( - torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()) - ) and (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()): + if ( + backend == Backend.triton + and not (torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())) + and (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) + ): return True # GPU availability @@ -73,7 +73,7 @@ def _should_skip(test_type: TestType, case, device_type: DeviceType, dtype, back return True # Gradient release and accumulation are GPU-only tests - if test_type in {TestType.gradient_release, TestType.accumulation} and device_type == DeviceType.cpu: + if test_type in (TestType.gradient_release, TestType.accumulation) and device_type == DeviceType.cpu: return True # bfloat16 is not supported on MPS @@ -85,18 +85,18 @@ def _should_skip(test_type: TestType, case, device_type: DeviceType, dtype, back ): return True - # Skip bfloat16 on CPU for most optimizers; allow anyadam-style exceptions via case.any_precision if needed + # Skip bfloat16 on CPU for most 
optimizers; allow anyadam exception via case.any_precision if device_type == DeviceType.cpu and dtype == torch.bfloat16 and not case.any_precision: return True return False -@pytest.mark.parametrize("case", OPTIM_PARAMS) +@pytest.mark.parametrize("case", OPTIMIZERS) @pytest.mark.parametrize("device_type", DEVICE_PARAMS) @pytest.mark.parametrize("dtype", DTYPE_PARAMS) @pytest.mark.parametrize("backend", BACKEND_PARAMS) -@pytest.mark.parametrize("dims_spec", CORRECTNESS_DIMS) +@pytest.mark.parametrize("dims_spec", FULL_DIMS) def test_correctness(case, device_type, dtype, backend, dims_spec, gpu_device): if _should_skip(TestType.correctness, case, device_type, dtype, backend): pytest.skip() @@ -107,22 +107,22 @@ def test_correctness(case, device_type, dtype, backend, dims_spec, gpu_device): run_correctness(case, device, dtype, backend, dims=dims) -@pytest.mark.parametrize("case", OPTIM_PARAMS) +@pytest.mark.parametrize("case", OPTIMIZERS) @pytest.mark.parametrize("device_type", [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)]) @pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) @pytest.mark.parametrize("backend", BACKEND_PARAMS) -@pytest.mark.parametrize("dims", GR_DIMS) +@pytest.mark.parametrize("dims", SUBSET_DIMS) def test_gradient_release(case, device_type, dtype, backend, dims, gpu_device): if _should_skip(TestType.gradient_release, case, device_type, dtype, backend): pytest.skip() run_gradient_release(case, torch.device(gpu_device), dtype, backend, dims=dims) -@pytest.mark.parametrize("case", OPTIM_PARAMS) +@pytest.mark.parametrize("case", OPTIMIZERS) @pytest.mark.parametrize("device_type", [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)]) @pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) @pytest.mark.parametrize("backend", BACKEND_PARAMS) -@pytest.mark.parametrize("dims", GR_DIMS) +@pytest.mark.parametrize("dims", SUBSET_DIMS) def test_accumulation(case, device_type, dtype, backend, dims, gpu_device): if _should_skip(TestType.accumulation, case, device_type, dtype, backend): pytest.skip() From a20e316c830bc334a4d1e9afbc6603a9b22a5412 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 23 Dec 2025 20:57:13 -0600 Subject: [PATCH 07/16] consolidate and simplfiy --- test_new/cases.py | 223 ---------------------------- test_new/config.py | 223 +++++++++++++++++++++++++++- test_new/conftest.py | 15 +- test_new/opt_adam.py | 6 +- test_new/opt_adamw.py | 6 +- test_new/opt_adan.py | 20 +-- test_new/opt_anyadam.py | 16 +- test_new/opt_lion.py | 15 +- test_new/opt_radam.py | 6 +- test_new/opt_ranger.py | 6 +- test_new/opt_sgd.py | 22 +-- test_new/opt_stableadamw.py | 8 +- test_new/runners.py | 3 +- test_new/test_optimizers_unified.py | 2 +- 14 files changed, 282 insertions(+), 289 deletions(-) delete mode 100644 test_new/cases.py diff --git a/test_new/cases.py b/test_new/cases.py deleted file mode 100644 index ef700b7..0000000 --- a/test_new/cases.py +++ /dev/null @@ -1,223 +0,0 @@ -import importlib -import inspect -import warnings -from dataclasses import asdict, dataclass, field, replace -from enum import Enum -from pathlib import Path -from typing import Any - -import torch -from optimi.optimizer import OptimiOptimizer -from torch.optim import Optimizer - - -class TestType(Enum): - correctness = "correctness" - gradient_release = "gradient_release" - accumulation = "accumulation" - - -class DeviceType(Enum): - cpu = 
"cpu" - gpu = "gpu" - - -class Backend(Enum): - torch = "torch" - triton = "triton" - foreach = "foreach" - - -@dataclass -class Tolerance: - atol: float = 1e-6 - rtol: float = 1e-5 - max_error_rate: float = 5e-4 - equal_nan: bool = False - - -@dataclass -class BaseParams: - lr: float = 1e-3 - weight_decay: float = 0.0 - decouple_wd: bool = False - decouple_lr: bool = False - triton: bool = False - - def with_(self, **overrides: Any) -> "BaseParams": - return replace(self, **overrides) - - def _kwargs_for(self, cls: type | None) -> dict[str, Any]: - if cls is None: - return {} - sig = inspect.signature(cls.__init__) - ok = set(sig.parameters) - {"self"} - values = asdict(self) - if values.get("triton") and "triton" not in ok: - warnings.warn(f"{cls.__name__} does not accept triton; ignoring BaseParams.triton=True.", RuntimeWarning) - return {k: v for k, v in values.items() if k in ok} - - def to_optimi_kwargs(self, cls: type[OptimiOptimizer]) -> dict[str, Any]: - return self._kwargs_for(cls) - - def to_reference_kwargs(self, cls: type[Optimizer]) -> dict[str, Any]: - return self._kwargs_for(cls) - - -@dataclass -class OptTest: - # Identification - name: str # e.g. "adam_base" - - # Classes + params - optimi_class: type[OptimiOptimizer] - optimi_params: BaseParams - reference_class: type[Optimizer] - reference_params: BaseParams | None = None - - # Optional fully decoupled reference for decoupled-lr variant - fully_decoupled_reference: type[Optimizer] | None = None - - # Behavior / constraints - test_decoupled_wd: bool = True - skip_tests: list[TestType] = field(default_factory=list) - any_precision: bool = False - custom_iterations: dict[TestType | tuple[TestType, DeviceType], int] | None = None - custom_tolerances: dict[torch.dtype, Tolerance] | None = None - only_dtypes: list[torch.dtype] | None = None - - def __post_init__(self): - if self.reference_params is None: - self.reference_params = self.optimi_params - if self.custom_tolerances is None: - self.custom_tolerances = {} - # reasonable defaults; override per-case as needed - self.custom_tolerances.setdefault(torch.float32, Tolerance()) - self.custom_tolerances.setdefault(torch.bfloat16, Tolerance(atol=1e-3, rtol=1e-2, max_error_rate=0.01)) - self.custom_tolerances.setdefault(torch.float16, Tolerance(atol=1e-4, rtol=1e-3, max_error_rate=0.01)) - - @property - def optimizer_name(self) -> str: - return self.name.split("_", 1)[0] - - @property - def variant_name(self) -> str: - return self.name.split("_", 1)[1] if "_" in self.name else "base" - - def to_optimi_kwargs(self, backend: Backend | None = None) -> dict[str, Any]: - # Both new BaseParams and legacy test_new.framework.BaseParams expose this - kw = self.optimi_params.to_optimi_kwargs(self.optimi_class) - - # Centralize backend controls so runners don't mutate kwargs later - if backend is not None: - if backend == Backend.triton: - kw["triton"] = True - kw["foreach"] = False - elif backend == Backend.torch: - kw["triton"] = False - kw["foreach"] = False - elif backend == Backend.foreach: - kw["triton"] = False - kw["foreach"] = True - else: - raise ValueError(f"Unknown backend: {backend}") - return kw - - def to_reference_kwargs(self, backend: Backend | None = None) -> dict[str, Any]: - assert self.reference_params is not None - kwargs = self.reference_params.to_reference_kwargs(self.reference_class) - # Centralize fused handling for reference optimizers: when not testing - # Optimi's Triton backend, avoid fused codepaths on the reference side - # to mirror legacy parity 
expectations. - if backend is not None and backend != Backend.triton: - try: - if "fused" in inspect.signature(self.reference_class.__init__).parameters: - kwargs = {**kwargs, "fused": False} - except (ValueError, TypeError): - pass - return kwargs - - def supports_l2_weight_decay(self) -> bool: - return "decouple_wd" in inspect.signature(self.optimi_class.__init__).parameters - - -def default_variants(base: OptTest) -> list[OptTest]: - """Generate base + L2 + decoupled variants with minimal boilerplate.""" - out: list[OptTest] = [] - - base_test = OptTest( - name=f"{base.optimizer_name}_base", - optimi_class=base.optimi_class, - optimi_params=base.optimi_params.with_(weight_decay=0.0, decouple_wd=False, decouple_lr=False), - reference_class=base.reference_class, - reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.0, decouple_wd=False, decouple_lr=False), - test_decoupled_wd=base.test_decoupled_wd, - skip_tests=list(base.skip_tests), - any_precision=base.any_precision, - custom_iterations=base.custom_iterations, - custom_tolerances=base.custom_tolerances, - only_dtypes=base.only_dtypes, - fully_decoupled_reference=base.fully_decoupled_reference, - ) - out.append(base_test) - - # L2 weight decay if optimizer supports decouple_wd arg - if "decouple_wd" in inspect.signature(base.optimi_class.__init__).parameters: - out.append( - replace( - base_test, - name=f"{base.optimizer_name}_l2_wd", - optimi_params=base.optimi_params.with_(weight_decay=0.01, decouple_wd=False), - reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.01, decouple_wd=False), - ) - ) - - # Decoupled weight decay - if base.test_decoupled_wd: - out.append( - replace( - base_test, - name=f"{base.optimizer_name}_decoupled_wd", - optimi_params=base.optimi_params.with_(weight_decay=0.01, decouple_wd=True), - reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.01, decouple_wd=True), - ) - ) - - # Decoupled LR (optionally swap reference class) - ref_cls = base.fully_decoupled_reference or base.reference_class - out.append( - replace( - base_test, - name=f"{base.optimizer_name}_decoupled_lr", - optimi_params=base.optimi_params.with_(weight_decay=1e-5, decouple_lr=True), - reference_class=ref_cls, - reference_params=(base.reference_params or base.optimi_params).with_( - weight_decay=1e-5 if base.fully_decoupled_reference else 0.01, - decouple_lr=True, - ), - ) - ) - return out - - -def discover_tests(root: Path | None = None) -> list[OptTest]: - """ - Discover `opt_*.py` modules in this package. 
Accept exactly: - - TESTS: list[OptTest] - - BASE: OptTest -> expanded via default_variants(BASE) - """ - if root is None: - root = Path(__file__).parent - cases: list[OptTest] = [] - for f in root.glob("opt_*.py"): - mod = importlib.import_module(f".{f.stem}", package=__package__) - if hasattr(mod, "TESTS"): - cases.extend(getattr(mod, "TESTS")) - elif hasattr(mod, "BASE"): - base = getattr(mod, "BASE") - cases.extend(default_variants(base)) - return cases - - -def optimizer_names() -> list[str]: - return sorted({c.optimizer_name for c in discover_tests()}) diff --git a/test_new/config.py b/test_new/config.py index ee0d2e3..8fd14db 100644 --- a/test_new/config.py +++ b/test_new/config.py @@ -1,10 +1,41 @@ from __future__ import annotations -from dataclasses import dataclass, field +import importlib +import inspect +import warnings +from dataclasses import asdict, dataclass, field, replace +from enum import Enum +from pathlib import Path +from typing import Any import torch +from optimi.optimizer import OptimiOptimizer +from torch.optim import Optimizer -from .cases import Tolerance + +class TestType(Enum): + correctness = "correctness" + gradient_release = "gradient_release" + accumulation = "accumulation" + + +class DeviceType(Enum): + cpu = "cpu" + gpu = "gpu" + + +class Backend(Enum): + torch = "torch" + triton = "triton" + foreach = "foreach" + + +@dataclass +class Tolerance: + atol: float = 1e-6 + rtol: float = 1e-5 + max_error_rate: float = 5e-4 + equal_nan: bool = False @dataclass(frozen=True) @@ -59,3 +90,191 @@ class TestDefaults: # Single place to tweak numbers used by runners DEFAULTS = TestDefaults() + + +@dataclass +class BaseParams: + lr: float = 1e-3 + weight_decay: float = 0.0 + decouple_wd: bool = False + decouple_lr: bool = False + triton: bool = False + + def with_(self, **overrides: Any) -> "BaseParams": + return replace(self, **overrides) + + def _kwargs_for(self, cls: type | None) -> dict[str, Any]: + if cls is None: + return {} + sig = inspect.signature(cls.__init__) + ok = set(sig.parameters) - {"self"} + values = asdict(self) + if values.get("triton") and "triton" not in ok: + warnings.warn(f"{cls.__name__} does not accept triton; ignoring BaseParams.triton=True.", RuntimeWarning) + return {k: v for k, v in values.items() if k in ok} + + def to_optimi_kwargs(self, cls: type[OptimiOptimizer]) -> dict[str, Any]: + return self._kwargs_for(cls) + + def to_reference_kwargs(self, cls: type[Optimizer]) -> dict[str, Any]: + return self._kwargs_for(cls) + + +@dataclass +class OptTest: + # Identification + name: str # e.g. 
"adam_base" + + # Classes + params + optimi_class: type[OptimiOptimizer] + optimi_params: BaseParams + reference_class: type[Optimizer] + reference_params: BaseParams | None = None + + # Optional fully decoupled reference for decoupled-lr variant + fully_decoupled_reference: type[Optimizer] | None = None + + # Behavior / constraints + skip_tests: list[TestType] = field(default_factory=list) + any_precision: bool = False + test_decoupled_wd: bool = True + custom_iterations: dict[TestType | tuple[TestType, DeviceType], int] | None = None + custom_tolerances: dict[torch.dtype, Tolerance] | None = None + only_dtypes: list[torch.dtype] | None = None + + def __post_init__(self): + if self.reference_params is None: + self.reference_params = self.optimi_params + if self.custom_tolerances is None: + self.custom_tolerances = {} + # reasonable defaults; override per-case as needed + self.custom_tolerances.setdefault(torch.float32, Tolerance()) + self.custom_tolerances.setdefault(torch.bfloat16, Tolerance(atol=1e-3, rtol=1e-2, max_error_rate=0.01)) + self.custom_tolerances.setdefault(torch.float16, Tolerance(atol=1e-4, rtol=1e-3, max_error_rate=0.01)) + + @property + def optimizer_name(self) -> str: + return self.name.split("_", 1)[0] + + @property + def variant_name(self) -> str: + return self.name.split("_", 1)[1] if "_" in self.name else "base" + + def to_optimi_kwargs(self, backend: Backend | None = None) -> dict[str, Any]: + kw = self.optimi_params.to_optimi_kwargs(self.optimi_class) + + # Centralize backend controls so runners don't mutate kwargs later + if backend is not None: + if backend == Backend.triton: + kw["triton"] = True + kw["foreach"] = False + elif backend == Backend.torch: + kw["triton"] = False + kw["foreach"] = False + elif backend == Backend.foreach: + kw["triton"] = False + kw["foreach"] = True + else: + raise ValueError(f"Unknown backend: {backend}") + return kw + + def to_reference_kwargs(self, backend: Backend | None = None) -> dict[str, Any]: + assert self.reference_params is not None + kwargs = self.reference_params.to_reference_kwargs(self.reference_class) + # Centralize fused handling for reference optimizers: when not testing + # Optimi's Triton backend, avoid fused codepaths on the reference side + # to mirror legacy parity expectations. 
+ if backend is not None and backend != Backend.triton: + try: + if "fused" in inspect.signature(self.reference_class.__init__).parameters: + kwargs = {**kwargs, "fused": False} + except (ValueError, TypeError): + pass + return kwargs + + def supports_l2_weight_decay(self) -> bool: + return "decouple_wd" in inspect.signature(self.optimi_class.__init__).parameters + + +def default_variants(base: OptTest) -> list[OptTest]: + """Generate base + L2 + decoupled variants with minimal boilerplate.""" + out: list[OptTest] = [] + + base_test = OptTest( + name=f"{base.optimizer_name}_base", + optimi_class=base.optimi_class, + optimi_params=base.optimi_params.with_(weight_decay=0.0, decouple_wd=False, decouple_lr=False), + reference_class=base.reference_class, + reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.0, decouple_wd=False, decouple_lr=False), + test_decoupled_wd=base.test_decoupled_wd, + skip_tests=list(base.skip_tests), + any_precision=base.any_precision, + custom_iterations=base.custom_iterations, + custom_tolerances=base.custom_tolerances, + only_dtypes=base.only_dtypes, + fully_decoupled_reference=base.fully_decoupled_reference, + ) + out.append(base_test) + + optimi_params = inspect.signature(base.optimi_class.__init__).parameters + + # L2 weight decay if optimizer supports decouple_wd arg + if "decouple_wd" in optimi_params: + out.append( + replace( + base_test, + name=f"{base.optimizer_name}_l2_wd", + optimi_params=base.optimi_params.with_(weight_decay=0.01, decouple_wd=False), + reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.01, decouple_wd=False), + ) + ) + + # Decoupled weight decay + if base.test_decoupled_wd and "decouple_lr" in optimi_params: + out.append( + replace( + base_test, + name=f"{base.optimizer_name}_decoupled_wd", + optimi_params=base.optimi_params.with_(weight_decay=0.01, decouple_wd=True), + reference_params=(base.reference_params or base.optimi_params).with_(weight_decay=0.01, decouple_wd=True), + ) + ) + + # Decoupled LR (optionally swap reference class) + ref_cls = base.fully_decoupled_reference or base.reference_class + out.append( + replace( + base_test, + name=f"{base.optimizer_name}_decoupled_lr", + optimi_params=base.optimi_params.with_(weight_decay=1e-5, decouple_lr=True), + reference_class=ref_cls, + reference_params=(base.reference_params or base.optimi_params).with_( + weight_decay=1e-5 if base.fully_decoupled_reference else 0.01, + decouple_lr=True, + ), + ) + ) + return out + + +def discover_tests(root: Path | None = None) -> list[OptTest]: + """ + Discover `opt_*.py` modules in this package. 
Accept exactly: + - TESTS: list[OptTest] + - BASE: OptTest -> expanded via default_variants(BASE) + """ + if root is None: + root = Path(__file__).parent + cases: list[OptTest] = [] + for f in root.glob("opt_*.py"): + mod = importlib.import_module(f".{f.stem}", package=__package__) + if hasattr(mod, "TESTS"): + cases.extend(getattr(mod, "TESTS")) + elif hasattr(mod, "BASE"): + base = getattr(mod, "BASE") + cases.extend(default_variants(base)) + return cases + + +def optimizer_names() -> list[str]: + return sorted({c.optimizer_name for c in discover_tests()}) diff --git a/test_new/conftest.py b/test_new/conftest.py index 323fdec..486378d 100644 --- a/test_new/conftest.py +++ b/test_new/conftest.py @@ -7,12 +7,11 @@ import pytest import torch -from .cases import optimizer_names +from .config import optimizer_names def pytest_configure(config): - """Configure pytest with custom marks for optimizer testing.""" - + "Configure pytest with custom marks for optimizer testing." # Register device marks config.addinivalue_line("markers", "cpu: mark test to run on CPU") config.addinivalue_line("markers", "gpu: mark test to run on GPU") @@ -31,14 +30,14 @@ def pytest_configure(config): def pytest_addoption(parser): - """Add command-line option to specify a single GPU""" + "Add command-line option to specify a single GPU." parser.addoption("--gpu-id", action="store", type=int, default=None, help="Specify a single GPU to use (e.g. --gpu-id=0)") @pytest.fixture() def gpu_device(worker_id, request): - """Map xdist workers to available GPU devices in a round-robin fashion, - supporting CUDA (NVIDIA/ROCm) and XPU (Intel) backends. + """Map xdist workers to available GPU devices in a round-robin fashion, supporting CUDA (NVIDIA/ROCm) and XPU (Intel) backends. + Use a single specified GPU if --gpu-id is provided""" # Check if specific GPU was requested @@ -55,9 +54,7 @@ def gpu_device(worker_id, request): backend = "mps" device_count = 0 else: - # Fallback to cuda for compatibility - backend = "cuda" - device_count = 0 + raise RuntimeError("No GPU backend available") if specific_gpu is not None: return torch.device(f"{backend}:{specific_gpu}") diff --git a/test_new/opt_adam.py b/test_new/opt_adam.py index 99afa15..e662184 100644 --- a/test_new/opt_adam.py +++ b/test_new/opt_adam.py @@ -5,7 +5,7 @@ import optimi import torch -from .cases import BaseParams, OptTest +from .config import BaseParams, OptTest @dataclass @@ -19,8 +19,8 @@ class AdamParams(BaseParams): BASE = OptTest( name="adam", optimi_class=optimi.Adam, - optimi_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.0), + optimi_params=AdamParams(), reference_class=torch.optim.Adam, - reference_params=AdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.0), + reference_params=AdamParams(), test_decoupled_wd=False, ) diff --git a/test_new/opt_adamw.py b/test_new/opt_adamw.py index f0cbc5a..9ea41d3 100644 --- a/test_new/opt_adamw.py +++ b/test_new/opt_adamw.py @@ -6,7 +6,7 @@ import torch from tests import reference -from .cases import BaseParams, OptTest +from .config import BaseParams, OptTest @dataclass @@ -19,8 +19,8 @@ class AdamWParams(BaseParams): BASE = OptTest( name="adamw", optimi_class=optimi.AdamW, - optimi_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + optimi_params=AdamWParams(), reference_class=torch.optim.AdamW, - reference_params=AdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + reference_params=AdamWParams(), 
fully_decoupled_reference=reference.DecoupledAdamW, ) diff --git a/test_new/opt_adan.py b/test_new/opt_adan.py index 48171a5..39f5d78 100644 --- a/test_new/opt_adan.py +++ b/test_new/opt_adan.py @@ -6,13 +6,13 @@ import optimi from tests import reference -from .cases import BaseParams, OptTest +from .config import BaseParams, OptTest @dataclass class AdanParams(BaseParams): betas: tuple[float, float, float] = (0.98, 0.92, 0.99) - eps: float = 1e-8 + eps: float = 1e-6 weight_decouple: bool = False # For adam_wd variant (maps to no_prox in reference) adam_wd: bool = False # For optimi optimizer @@ -28,29 +28,29 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: OptTest( name="adan_base", optimi_class=optimi.Adan, - optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=0), + optimi_params=AdanParams(), reference_class=reference.Adan, - reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6), + reference_params=AdanParams(), ), OptTest( name="adan_weight_decay", optimi_class=optimi.Adan, - optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), + optimi_params=AdanParams(weight_decay=2e-2), reference_class=reference.Adan, - reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), + reference_params=AdanParams(weight_decay=2e-2), ), OptTest( name="adan_adam_wd", optimi_class=optimi.Adan, - optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, adam_wd=True), + optimi_params=AdanParams(weight_decay=2e-2, adam_wd=True), reference_class=reference.Adan, - reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, weight_decouple=True), + reference_params=AdanParams(weight_decay=2e-2, weight_decouple=True), ), OptTest( name="adan_decoupled_lr", optimi_class=optimi.Adan, - optimi_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-5, decouple_lr=True), + optimi_params=AdanParams(weight_decay=2e-5, decouple_lr=True), reference_class=reference.Adan, - reference_params=AdanParams(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2), + reference_params=AdanParams(weight_decay=2e-2), ), ] diff --git a/test_new/opt_anyadam.py b/test_new/opt_anyadam.py index 8301f03..adeaaa2 100644 --- a/test_new/opt_anyadam.py +++ b/test_new/opt_anyadam.py @@ -6,13 +6,13 @@ import torch from tests.reference import AnyPrecisionAdamW -from .cases import BaseParams, OptTest, Tolerance +from .config import BaseParams, OptTest, Tolerance @dataclass class AnyAdamParams(BaseParams): betas: tuple[float, float] = (0.9, 0.999) - eps: float = 1e-8 + eps: float = 1e-6 kahan_sum: bool = False use_kahan_summation: bool = False @@ -31,9 +31,9 @@ def to_reference_kwargs(self, reference_class: type) -> dict: OptTest( name="anyadam_kahan", optimi_class=optimi.Adam, - optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, kahan_sum=True), + optimi_params=AnyAdamParams(betas=(0.9, 0.99), kahan_sum=True), reference_class=AnyPrecisionAdamW, - reference_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, use_kahan_summation=True), + reference_params=AnyAdamParams(betas=(0.9, 0.99), use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, custom_tolerances={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01, equal_nan=False)}, @@ -41,9 +41,9 @@ def to_reference_kwargs(self, reference_class: type) -> dict: OptTest( 
name="anyadam_kahan_wd", optimi_class=optimi.AdamW, - optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.01, kahan_sum=True), + optimi_params=AnyAdamParams(betas=(0.9, 0.99), weight_decay=0.01, kahan_sum=True), reference_class=AnyPrecisionAdamW, - reference_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.01, use_kahan_summation=True), + reference_params=AnyAdamParams(betas=(0.9, 0.99), weight_decay=0.01, use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, custom_tolerances={torch.bfloat16: Tolerance(rtol=5e-2, atol=1e-2, max_error_rate=0.01, equal_nan=False)}, @@ -51,9 +51,9 @@ def to_reference_kwargs(self, reference_class: type) -> dict: OptTest( name="anyadam_kahan_decoupled_lr", optimi_class=optimi.AdamW, - optimi_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-5, decouple_lr=True, kahan_sum=True), + optimi_params=AnyAdamParams(betas=(0.9, 0.99), weight_decay=1e-5, decouple_lr=True, kahan_sum=True), reference_class=AnyPrecisionAdamW, - reference_params=AnyAdamParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2, use_kahan_summation=True), + reference_params=AnyAdamParams(betas=(0.9, 0.99), weight_decay=1e-2, use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, custom_tolerances={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01, equal_nan=False)}, diff --git a/test_new/opt_lion.py b/test_new/opt_lion.py index 63c1f84..93cf97a 100644 --- a/test_new/opt_lion.py +++ b/test_new/opt_lion.py @@ -5,11 +5,12 @@ import optimi from tests.reference import lion as reference_lion -from .cases import BaseParams, OptTest +from .config import BaseParams, OptTest @dataclass class LionParams(BaseParams): + lr: float = 1e-4 betas: tuple[float, float] = (0.9, 0.99) @@ -17,22 +18,22 @@ class LionParams(BaseParams): OptTest( name="lion_base", optimi_class=optimi.Lion, - optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), + optimi_params=LionParams(), reference_class=reference_lion.Lion, - reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0), + reference_params=LionParams(), ), OptTest( name="lion_decoupled_wd", optimi_class=optimi.Lion, - optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1, decouple_wd=True), + optimi_params=LionParams(weight_decay=0.1, decouple_wd=True), reference_class=reference_lion.Lion, - reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1), + reference_params=LionParams(weight_decay=0.1), ), OptTest( name="lion_decoupled_lr", optimi_class=optimi.Lion, - optimi_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-5, decouple_lr=True), + optimi_params=LionParams(weight_decay=1e-5, decouple_lr=True), reference_class=reference_lion.Lion, - reference_params=LionParams(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1), + reference_params=LionParams(weight_decay=0.1), ), ] diff --git a/test_new/opt_radam.py b/test_new/opt_radam.py index ed84ff3..ed4a026 100644 --- a/test_new/opt_radam.py +++ b/test_new/opt_radam.py @@ -6,7 +6,7 @@ import optimi import torch -from .cases import BaseParams, OptTest, Tolerance +from .config import BaseParams, OptTest, Tolerance @dataclass @@ -23,9 +23,9 @@ def __post_init__(self): BASE = OptTest( name="radam", optimi_class=optimi.RAdam, - optimi_params=RAdamParams(lr=1e-3, betas=(0.9, 0.99), weight_decay=0), + optimi_params=RAdamParams(), reference_class=torch.optim.RAdam, - reference_params=RAdamParams(lr=1e-3, betas=(0.9, 
0.99), weight_decay=0), + reference_params=RAdamParams(), custom_tolerances={torch.float32: Tolerance(max_error_rate=0.001)}, test_decoupled_wd="decoupled_weight_decay" in inspect.signature(torch.optim.RAdam.__init__).parameters, ) diff --git a/test_new/opt_ranger.py b/test_new/opt_ranger.py index b831704..0ffe25c 100644 --- a/test_new/opt_ranger.py +++ b/test_new/opt_ranger.py @@ -5,7 +5,7 @@ import optimi from tests import reference -from .cases import BaseParams, OptTest, TestType +from .config import BaseParams, OptTest, TestType @dataclass @@ -20,9 +20,9 @@ class RangerParams(BaseParams): OptTest( name="ranger_base", optimi_class=optimi.Ranger, - optimi_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), + optimi_params=RangerParams(), reference_class=reference.Ranger, - reference_params=RangerParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0), + reference_params=RangerParams(), # Match legacy longer gradient-release coverage due to Lookahead cadence. custom_iterations={TestType.gradient_release: 160}, ) diff --git a/test_new/opt_sgd.py b/test_new/opt_sgd.py index 642b31d..4528fff 100644 --- a/test_new/opt_sgd.py +++ b/test_new/opt_sgd.py @@ -7,7 +7,7 @@ import torch from tests import reference -from .cases import BaseParams, OptTest, TestType +from .config import BaseParams, OptTest, TestType @dataclass @@ -29,38 +29,38 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: OptTest( name="sgd_base", optimi_class=optimi.SGD, - optimi_params=SGDParams(lr=1e-3, momentum=0, dampening=False, weight_decay=0), + optimi_params=SGDParams(), reference_class=torch.optim.SGD, - reference_params=SGDParams(lr=1e-3, momentum=0, dampening=0, weight_decay=0), + reference_params=SGDParams(), skip_tests=[TestType.accumulation], ), OptTest( name="sgd_momentum", optimi_class=optimi.SGD, - optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=0), + optimi_params=SGDParams(momentum=0.9), reference_class=torch.optim.SGD, - reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=0), + reference_params=SGDParams(momentum=0.9), ), OptTest( name="sgd_dampening", optimi_class=optimi.SGD, - optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=True, weight_decay=0, torch_init=True), + optimi_params=SGDParams(momentum=0.9, dampening=True, torch_init=True), reference_class=torch.optim.SGD, - reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=0), + reference_params=SGDParams(momentum=0.9, dampening=0.9), ), OptTest( name="sgd_weight_decay", optimi_class=optimi.SGD, - optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=False, weight_decay=1e-2, decouple_wd=False), + optimi_params=SGDParams(momentum=0.9, weight_decay=1e-2), reference_class=torch.optim.SGD, - reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0, weight_decay=1e-2), + reference_params=SGDParams(momentum=0.9, weight_decay=1e-2), skip_tests=[TestType.accumulation], ), OptTest( name="sgd_decoupled_lr", optimi_class=optimi.SGD, - optimi_params=SGDParams(lr=1e-3, momentum=0.9, dampening=True, decouple_lr=True, weight_decay=1e-5, torch_init=True), + optimi_params=SGDParams(momentum=0.9, dampening=True, decouple_lr=True, weight_decay=1e-5, torch_init=True), reference_class=reference.DecoupledSGDW, - reference_params=SGDParams(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=1e-5), + reference_params=SGDParams(momentum=0.9, dampening=0.9, weight_decay=1e-5), ), ] diff --git a/test_new/opt_stableadamw.py 
b/test_new/opt_stableadamw.py index 709ea2f..69f1507 100644 --- a/test_new/opt_stableadamw.py +++ b/test_new/opt_stableadamw.py @@ -5,19 +5,19 @@ import optimi from tests import reference -from .cases import BaseParams, OptTest +from .config import BaseParams, OptTest @dataclass class StableAdamWParams(BaseParams): betas: tuple[float, float] = (0.9, 0.99) - eps: float = 1e-8 + eps: float = 1e-6 BASE = OptTest( name="stableadamw", optimi_class=optimi.StableAdamW, - optimi_params=StableAdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + optimi_params=StableAdamWParams(), reference_class=reference.StableAdamWUnfused, - reference_params=StableAdamWParams(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0), + reference_params=StableAdamWParams(), ) diff --git a/test_new/runners.py b/test_new/runners.py index a5b1f33..b171f89 100644 --- a/test_new/runners.py +++ b/test_new/runners.py @@ -7,8 +7,7 @@ from optimi import prepare_for_gradient_release, remove_gradient_release from torch import Tensor -from .cases import Backend, DeviceType, OptTest, TestType, Tolerance -from .config import DEFAULTS +from .config import DEFAULTS, Backend, DeviceType, OptTest, TestType, Tolerance def _device_type(device: torch.device) -> DeviceType: diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py index de4d848..07cf04d 100644 --- a/test_new/test_optimizers_unified.py +++ b/test_new/test_optimizers_unified.py @@ -2,7 +2,7 @@ import torch from optimi.utils import MIN_TORCH_2_6 -from .cases import Backend, DeviceType, OptTest, TestType, discover_tests +from .config import Backend, DeviceType, OptTest, TestType, discover_tests from .runners import run_accumulation, run_correctness, run_gradient_release DEVICE_PARAMS = [ From f83f106f2a3e48bce7727ad46a9b30c25ed2170d Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 23 Dec 2025 21:12:09 -0600 Subject: [PATCH 08/16] rename --- test_new/runners.py | 62 ++++++++++++++--------------- test_new/test_optimizers_unified.py | 37 ++++++++--------- 2 files changed, 50 insertions(+), 49 deletions(-) diff --git a/test_new/runners.py b/test_new/runners.py index b171f89..c2dde32 100644 --- a/test_new/runners.py +++ b/test_new/runners.py @@ -15,18 +15,18 @@ def _device_type(device: torch.device) -> DeviceType: def _get_iterations( - case: OptTest, + opttest: OptTest, test_type: TestType, default: int, device: torch.device | None = None, ) -> int: - if not case.custom_iterations: + if not opttest.custom_iterations: return default if device is not None: key = (test_type, _device_type(device)) - if key in case.custom_iterations: - return case.custom_iterations[key] - return case.custom_iterations.get(test_type, default) + if key in opttest.custom_iterations: + return opttest.custom_iterations[key] + return opttest.custom_iterations.get(test_type, default) def assert_most_approx_close( @@ -71,7 +71,7 @@ def forward(self, x: Tensor) -> Tensor: def run_correctness( - case: OptTest, + opttest: OptTest, device: torch.device, dtype: torch.dtype, backend: Backend, @@ -79,11 +79,11 @@ def run_correctness( ) -> None: # Iterations and tolerance default_iters = DEFAULTS.correctness.cpu_iterations if device.type == "cpu" else DEFAULTS.correctness.gpu_iterations - iterations = _get_iterations(case, TestType.correctness, default_iters, device=device) - # Special-case: Adan bf16 on GPU - if device.type != "cpu" and dtype == torch.bfloat16 and case.optimizer_name == "adan": + iterations = _get_iterations(opttest, TestType.correctness, 
default_iters, device=device) + # Special-opttest: Adan bf16 on GPU + if device.type != "cpu" and dtype == torch.bfloat16 and opttest.optimizer_name == "adan": iterations = DEFAULTS.correctness.adan_bf16_gpu_iterations - tolerance = case.custom_tolerances[dtype] + tolerance = opttest.custom_tolerances[dtype] # Dims, batch, errors if dims is not None: dim1, dim2 = dims @@ -100,16 +100,16 @@ def run_correctness( m2.load_state_dict(m1.state_dict()) # Convert parameters to float for non-any_precision - if not case.any_precision and dtype != torch.float32: + if not opttest.any_precision and dtype != torch.float32: for p in m1.parameters(): p.data = p.data.float() # Optimizers - reference_class = case.reference_class - reference_kwargs = case.to_reference_kwargs(backend) - optimi_kwargs = case.to_optimi_kwargs(backend) + reference_class = opttest.reference_class + reference_kwargs = opttest.to_reference_kwargs(backend) + optimi_kwargs = opttest.to_optimi_kwargs(backend) reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) - optimi_optimizer = case.optimi_class(m2.parameters(), **optimi_kwargs) + optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) buffer = io.BytesIO() @@ -119,7 +119,7 @@ def run_correctness( target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) target2 = target1.detach().clone() - if not case.any_precision and dtype != torch.float32: + if not opttest.any_precision and dtype != torch.float32: input1 = input1.float() target1 = target1.float() @@ -159,7 +159,7 @@ def run_correctness( torch.save(optimi_optimizer.state_dict(), buffer) buffer.seek(0) ckpt = torch.load(buffer, weights_only=True) - optimi_optimizer = case.optimi_class(m2.parameters(), **optimi_kwargs) + optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) optimi_optimizer.load_state_dict(ckpt) buffer.seek(0) buffer.truncate(0) @@ -185,7 +185,7 @@ def run_correctness( def run_gradient_release( - case: OptTest, + opttest: OptTest, device: torch.device, dtype: torch.dtype, backend: Backend, @@ -196,10 +196,10 @@ def optimizer_hook(parameter) -> None: torch_optimizers[parameter].zero_grad() # Iterations - iterations = _get_iterations(case, TestType.gradient_release, DEFAULTS.gradient_release.iterations, device=device) + iterations = _get_iterations(opttest, TestType.gradient_release, DEFAULTS.gradient_release.iterations, device=device) - # Tolerances: merge baseline with per-case - tol = case.custom_tolerances[dtype] + # Tolerances: merge baseline with per-opttest + tol = opttest.custom_tolerances[dtype] baseline = DEFAULTS.gradient_release.baseline_tolerance.get(dtype, tol) tolerance = Tolerance( rtol=max(tol.rtol, baseline.rtol), @@ -224,9 +224,9 @@ def optimizer_hook(parameter) -> None: m2.load_state_dict(m1.state_dict()) m3.load_state_dict(m1.state_dict()) - reference_class = case.reference_class - reference_kwargs = case.to_reference_kwargs(backend) - optimi_kwargs = case.to_optimi_kwargs(backend) + reference_class = opttest.reference_class + reference_kwargs = opttest.to_reference_kwargs(backend) + optimi_kwargs = opttest.to_optimi_kwargs(backend) regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) torch_optimizers = {p: reference_class([p], **reference_kwargs) for p in m2.parameters()} @@ -235,7 +235,7 @@ def optimizer_hook(parameter) -> None: pytorch_hooks.append(p.register_post_accumulate_grad_hook(optimizer_hook)) optimi_kwargs["gradient_release"] = True - optimi_optimizer = case.optimi_class(m3.parameters(), 
**optimi_kwargs) + optimi_optimizer = opttest.optimi_class(m3.parameters(), **optimi_kwargs) prepare_for_gradient_release(m3, optimi_optimizer) for _ in range(iterations): @@ -309,14 +309,14 @@ def optimizer_hook(parameter) -> None: def run_accumulation( - case: OptTest, + opttest: OptTest, device: torch.device, dtype: torch.dtype, backend: Backend, dims: tuple[int, int] | None = None, ) -> None: # Iterations - iterations = _get_iterations(case, TestType.accumulation, DEFAULTS.accumulation.iterations, device=device) + iterations = _get_iterations(opttest, TestType.accumulation, DEFAULTS.accumulation.iterations, device=device) # Dims and batch size if dims is not None: @@ -334,13 +334,13 @@ def run_accumulation( m2 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi accumulation m2.load_state_dict(m1.state_dict()) - reference_class = case.reference_class - reference_kwargs = case.to_reference_kwargs(backend) - optimi_kwargs = case.to_optimi_kwargs(backend) + reference_class = opttest.reference_class + reference_kwargs = opttest.to_reference_kwargs(backend) + optimi_kwargs = opttest.to_optimi_kwargs(backend) regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) optimi_kwargs["gradient_release"] = True - optimi_optimizer = case.optimi_class(m2.parameters(), **optimi_kwargs) + optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) prepare_for_gradient_release(m2, optimi_optimizer) gradient_accumulation_steps = DEFAULTS.accumulation.gradient_accumulation_steps diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py index 07cf04d..aa1b4e2 100644 --- a/test_new/test_optimizers_unified.py +++ b/test_new/test_optimizers_unified.py @@ -15,6 +15,7 @@ ] BACKEND_PARAMS = [ pytest.param(Backend.torch, marks=pytest.mark.torch, id=Backend.torch.value), + pytest.param(Backend.foreach, marks=pytest.mark.foreach, id=Backend.foreach.value), pytest.param(Backend.triton, marks=pytest.mark.triton, id=Backend.triton.value), ] @@ -39,13 +40,13 @@ ] -def _should_skip(test_type: TestType, case: OptTest, device_type: DeviceType, dtype: torch.dtype, backend: Backend) -> bool: - # Explicit per-case skip - if test_type in set(case.skip_tests): +def _should_skip(test_type: TestType, opttest: OptTest, device_type: DeviceType, dtype: torch.dtype, backend: Backend) -> bool: + # Explicit per-opttest skip + if test_type in set(opttest.skip_tests): return True # Respect per-test dtype constraints if provided - if case.only_dtypes and dtype not in case.only_dtypes: + if opttest.only_dtypes and dtype not in opttest.only_dtypes: return True # Skip triton on CPU @@ -85,45 +86,45 @@ def _should_skip(test_type: TestType, case: OptTest, device_type: DeviceType, dt ): return True - # Skip bfloat16 on CPU for most optimizers; allow anyadam exception via case.any_precision - if device_type == DeviceType.cpu and dtype == torch.bfloat16 and not case.any_precision: + # Skip bfloat16 on CPU for most optimizers; allow anyadam exception via opttest.any_precision + if device_type == DeviceType.cpu and dtype == torch.bfloat16 and not opttest.any_precision: return True return False -@pytest.mark.parametrize("case", OPTIMIZERS) +@pytest.mark.parametrize("opttest", OPTIMIZERS) @pytest.mark.parametrize("device_type", DEVICE_PARAMS) @pytest.mark.parametrize("dtype", DTYPE_PARAMS) @pytest.mark.parametrize("backend", BACKEND_PARAMS) @pytest.mark.parametrize("dims_spec", FULL_DIMS) -def test_correctness(case, device_type, dtype, backend, dims_spec, gpu_device): - if 
_should_skip(TestType.correctness, case, device_type, dtype, backend): +def test_correctness(opttest, device_type, dtype, backend, dims_spec, gpu_device): + if _should_skip(TestType.correctness, opttest, device_type, dtype, backend): pytest.skip() dims_device, dims = dims_spec if dims_device != device_type: pytest.skip() device = torch.device(gpu_device if device_type == DeviceType.gpu else "cpu") - run_correctness(case, device, dtype, backend, dims=dims) + run_correctness(opttest, device, dtype, backend, dims=dims) -@pytest.mark.parametrize("case", OPTIMIZERS) +@pytest.mark.parametrize("opttest", OPTIMIZERS) @pytest.mark.parametrize("device_type", [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)]) @pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) @pytest.mark.parametrize("backend", BACKEND_PARAMS) @pytest.mark.parametrize("dims", SUBSET_DIMS) -def test_gradient_release(case, device_type, dtype, backend, dims, gpu_device): - if _should_skip(TestType.gradient_release, case, device_type, dtype, backend): +def test_gradient_release(opttest, device_type, dtype, backend, dims, gpu_device): + if _should_skip(TestType.gradient_release, opttest, device_type, dtype, backend): pytest.skip() - run_gradient_release(case, torch.device(gpu_device), dtype, backend, dims=dims) + run_gradient_release(opttest, torch.device(gpu_device), dtype, backend, dims=dims) -@pytest.mark.parametrize("case", OPTIMIZERS) +@pytest.mark.parametrize("opttest", OPTIMIZERS) @pytest.mark.parametrize("device_type", [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)]) @pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) @pytest.mark.parametrize("backend", BACKEND_PARAMS) @pytest.mark.parametrize("dims", SUBSET_DIMS) -def test_accumulation(case, device_type, dtype, backend, dims, gpu_device): - if _should_skip(TestType.accumulation, case, device_type, dtype, backend): +def test_accumulation(opttest, device_type, dtype, backend, dims, gpu_device): + if _should_skip(TestType.accumulation, opttest, device_type, dtype, backend): pytest.skip() - run_accumulation(case, torch.device(gpu_device), dtype, backend, dims=dims) + run_accumulation(opttest, torch.device(gpu_device), dtype, backend, dims=dims) From e4364e9bc1ee086e58c93fbab39fb8ff648eac83 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 23 Dec 2025 21:39:18 -0600 Subject: [PATCH 09/16] skip tests while collecting --- test_new/conftest.py | 5 +- test_new/test_optimizers_unified.py | 86 +++++++++++++++++++++-------- 2 files changed, 65 insertions(+), 26 deletions(-) diff --git a/test_new/conftest.py b/test_new/conftest.py index 486378d..eda353d 100644 --- a/test_new/conftest.py +++ b/test_new/conftest.py @@ -22,6 +22,7 @@ def pytest_configure(config): # Register backend marks config.addinivalue_line("markers", "torch: mark test to run with torch backend") + config.addinivalue_line("markers", "foreach: mark test to run with foreach backend") config.addinivalue_line("markers", "triton: mark test to run with triton backend") # Per-optimizer marks (e.g., -m adam, -m sgd) @@ -36,8 +37,8 @@ def pytest_addoption(parser): @pytest.fixture() def gpu_device(worker_id, request): - """Map xdist workers to available GPU devices in a round-robin fashion, supporting CUDA (NVIDIA/ROCm) and XPU (Intel) backends. 
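Aside (not part of the patch): the gpu_device docstring being reflowed here describes a round-robin mapping of pytest-xdist workers onto GPUs, but the fixture body sits outside this hunk. The sketch below is therefore only a hypothetical illustration of that mapping; pick_gpu and the "gw<N>" worker-id parsing are assumptions, not the actual fixture.

import torch

def pick_gpu(worker_id: str) -> torch.device:
    # Hypothetical mapping: pytest-xdist worker ids look like "gw0", "gw1", ...; "master" otherwise.
    index = int(worker_id[2:]) if worker_id.startswith("gw") else 0
    if torch.cuda.is_available():
        return torch.device(f"cuda:{index % torch.cuda.device_count()}")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device(f"xpu:{index % torch.xpu.device_count()}")
    return torch.device("cpu")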
- + """Map xdist workers to available GPU devices in a round-robin fashion, + supporting CUDA (NVIDIA/ROCm) and XPU (Intel) backends. Use a single specified GPU if --gpu-id is provided""" # Check if specific GPU was requested diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py index aa1b4e2..9f08d9f 100644 --- a/test_new/test_optimizers_unified.py +++ b/test_new/test_optimizers_unified.py @@ -1,5 +1,6 @@ import pytest import torch +from _pytest.mark.structures import ParameterSet from optimi.utils import MIN_TORCH_2_6 from .config import Backend, DeviceType, OptTest, TestType, discover_tests @@ -93,38 +94,75 @@ def _should_skip(test_type: TestType, opttest: OptTest, device_type: DeviceType, return False -@pytest.mark.parametrize("opttest", OPTIMIZERS) -@pytest.mark.parametrize("device_type", DEVICE_PARAMS) -@pytest.mark.parametrize("dtype", DTYPE_PARAMS) -@pytest.mark.parametrize("backend", BACKEND_PARAMS) -@pytest.mark.parametrize("dims_spec", FULL_DIMS) +def _param_value(param: ParameterSet) -> object: + return param.values[0] + + +def _param_id(param: ParameterSet) -> str: + return param.id or str(param.values[0]) + + +def _build_params(test_type: TestType) -> list[ParameterSet]: + if test_type == TestType.correctness: + device_params = DEVICE_PARAMS + dtype_params = DTYPE_PARAMS + dims_params = FULL_DIMS + else: + device_params = [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)] + dtype_params = [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")] + dims_params = SUBSET_DIMS + + params: list[ParameterSet] = [] + for opt_param in OPTIMIZERS: + for device_param in device_params: + for dtype_param in dtype_params: + for backend_param in BACKEND_PARAMS: + for dims_param in dims_params: + if test_type == TestType.correctness and _param_value(dims_param)[0] != _param_value(device_param): + continue + if _should_skip( + test_type, + _param_value(opt_param), + _param_value(device_param), + _param_value(dtype_param), + _param_value(backend_param), + ): + continue + param_id = "-".join( + [ + _param_id(dims_param), + _param_id(backend_param), + _param_id(dtype_param), + _param_id(device_param), + _param_id(opt_param), + ] + ) + params.append( + pytest.param( + _param_value(opt_param), + _param_value(device_param), + _param_value(dtype_param), + _param_value(backend_param), + _param_value(dims_param), + id=param_id, + marks=list(opt_param.marks + device_param.marks + dtype_param.marks + backend_param.marks), + ) + ) + return params + + +@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims_spec", _build_params(TestType.correctness)) def test_correctness(opttest, device_type, dtype, backend, dims_spec, gpu_device): - if _should_skip(TestType.correctness, opttest, device_type, dtype, backend): - pytest.skip() - dims_device, dims = dims_spec - if dims_device != device_type: - pytest.skip() + _, dims = dims_spec device = torch.device(gpu_device if device_type == DeviceType.gpu else "cpu") run_correctness(opttest, device, dtype, backend, dims=dims) -@pytest.mark.parametrize("opttest", OPTIMIZERS) -@pytest.mark.parametrize("device_type", [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)]) -@pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) -@pytest.mark.parametrize("backend", BACKEND_PARAMS) -@pytest.mark.parametrize("dims", SUBSET_DIMS) +@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", 
_build_params(TestType.gradient_release)) def test_gradient_release(opttest, device_type, dtype, backend, dims, gpu_device): - if _should_skip(TestType.gradient_release, opttest, device_type, dtype, backend): - pytest.skip() run_gradient_release(opttest, torch.device(gpu_device), dtype, backend, dims=dims) -@pytest.mark.parametrize("opttest", OPTIMIZERS) -@pytest.mark.parametrize("device_type", [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)]) -@pytest.mark.parametrize("dtype", [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")]) -@pytest.mark.parametrize("backend", BACKEND_PARAMS) -@pytest.mark.parametrize("dims", SUBSET_DIMS) +@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(TestType.accumulation)) def test_accumulation(opttest, device_type, dtype, backend, dims, gpu_device): - if _should_skip(TestType.accumulation, opttest, device_type, dtype, backend): - pytest.skip() run_accumulation(opttest, torch.device(gpu_device), dtype, backend, dims=dims) From b92af3b67ed9f8ff6dafb1d43324a58d51ef1304 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 26 Dec 2025 11:17:49 -0600 Subject: [PATCH 10/16] unify test runner, use a TestSpec for defaults overridable per optimizer --- test_new/config.py | 117 +++++---- test_new/opt_adan.py | 7 +- test_new/opt_anyadam.py | 20 +- test_new/opt_radam.py | 4 +- test_new/opt_ranger.py | 4 +- test_new/opt_sgd.py | 6 +- test_new/runner.py | 332 ++++++++++++++++++++++++ test_new/runners.py | 376 ---------------------------- test_new/test_optimizers_unified.py | 32 +-- 9 files changed, 452 insertions(+), 446 deletions(-) create mode 100644 test_new/runner.py delete mode 100644 test_new/runners.py diff --git a/test_new/config.py b/test_new/config.py index 8fd14db..116531f 100644 --- a/test_new/config.py +++ b/test_new/config.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer -class TestType(Enum): - correctness = "correctness" +class OptTestType(Enum): + default = "default" gradient_release = "gradient_release" accumulation = "accumulation" @@ -38,31 +38,31 @@ class Tolerance: equal_nan: bool = False -@dataclass(frozen=True) -class CorrectnessDefaults: - cpu_iterations: int = 20 - gpu_iterations: int = 40 - # Special-case: Adan in bf16 on GPU is noisier; align to 20 - adan_bf16_gpu_iterations: int = 20 +@dataclass() +class CorrectnessSpec: + iterations_cpu: int = 20 + iterations_gpu: int = 40 + batch_cpu: int = 1 + batch_gpu: int = 32 + max_error_cpu: int = 2 + max_error_gpu: int = 5 - cpu_dims: tuple[int, int] = (64, 128) - gpu_dims: tuple[int, int] = (256, 512) - - cpu_batch_size: int = 1 - gpu_batch_size: int = 32 - - cpu_max_error_count: int = 2 - gpu_max_error_count: int = 5 + tolerance: dict[torch.dtype, Tolerance] = field( + default_factory=lambda: { + torch.float32: Tolerance(atol=1e-6, rtol=1e-5, max_error_rate=5e-4), + torch.bfloat16: Tolerance(atol=1e-3, rtol=1e-2, max_error_rate=0.01), + torch.float16: Tolerance(atol=1e-4, rtol=1e-3, max_error_rate=0.01), + } + ) -@dataclass(frozen=True) -class GradientReleaseDefaults: +@dataclass() +class GradientReleaseSpec: iterations: int = 40 - dims: tuple[int, int] = (128, 256) - batch_size: int = 32 + batch: int = 32 max_error_count: int = 12 # more lenient for noisy updates - baseline_tolerance: dict[torch.dtype, Tolerance] = field( + tolerance: dict[torch.dtype, Tolerance] = field( default_factory=lambda: { torch.float32: Tolerance(atol=1e-6, rtol=1e-5, max_error_rate=5e-4), torch.bfloat16: Tolerance(atol=1e-3, 
rtol=1e-2, max_error_rate=0.01), @@ -71,25 +71,60 @@ class GradientReleaseDefaults: ) -@dataclass(frozen=True) -class AccumulationDefaults: +@dataclass() +class AccumulationSpec: iterations: int = 40 - dims: tuple[int, int] = (128, 256) - batch_size: int = 32 - tolerance: Tolerance = field(default_factory=lambda: Tolerance(rtol=1e-2, atol=1e-2)) + batch: int = 32 max_error_rate: float = 0.035 gradient_accumulation_steps: int = 4 - -@dataclass(frozen=True) -class TestDefaults: - correctness: CorrectnessDefaults = CorrectnessDefaults() - gradient_release: GradientReleaseDefaults = GradientReleaseDefaults() - accumulation: AccumulationDefaults = AccumulationDefaults() + tolerance: dict[torch.dtype, Tolerance] = field( + default_factory=lambda: { + torch.float32: Tolerance(rtol=1e-2, atol=1e-2), + torch.bfloat16: Tolerance(rtol=1e-2, atol=1e-2), + torch.float16: Tolerance(rtol=1e-2, atol=1e-2), + } + ) -# Single place to tweak numbers used by runners -DEFAULTS = TestDefaults() +@dataclass() +class TestSpec: + default: CorrectnessSpec = CorrectnessSpec() + gradient_release: GradientReleaseSpec = GradientReleaseSpec() + accumulation: AccumulationSpec = AccumulationSpec() + + +def with_updated_spec( + spec: TestSpec | CorrectnessSpec | GradientReleaseSpec | AccumulationSpec | None, + test_type: OptTestType | None = None, + tolerances_override: dict[torch.dtype, Tolerance] | None = None, +) -> TestSpec: + if isinstance(spec, (CorrectnessSpec, GradientReleaseSpec, AccumulationSpec)): + if isinstance(spec, CorrectnessSpec): + base = TestSpec(default=spec) + elif isinstance(spec, GradientReleaseSpec): + base = TestSpec(gradient_release=spec) + else: + base = TestSpec(accumulation=spec) + else: + base = spec or TestSpec() + + if tolerances_override is None: + tolerances_override = {} + + if test_type is None: + return base + + if test_type == OptTestType.default: + merged = {**base.default.tolerance, **tolerances_override} + return replace(base, default=replace(base.default, tolerance=merged)) + if test_type == OptTestType.gradient_release: + merged = {**base.gradient_release.baseline_tolerance, **tolerances_override} + return replace(base, gradient_release=replace(base.gradient_release, baseline_tolerance=merged)) + if test_type == OptTestType.accumulation: + merged = {**base.accumulation.tolerance, **tolerances_override} + return replace(base, accumulation=replace(base.accumulation, tolerance=merged)) + raise ValueError(f"Unknown test type: {test_type}") @dataclass @@ -135,22 +170,16 @@ class OptTest: fully_decoupled_reference: type[Optimizer] | None = None # Behavior / constraints - skip_tests: list[TestType] = field(default_factory=list) + skip_tests: list[OptTestType] = field(default_factory=list) any_precision: bool = False test_decoupled_wd: bool = True - custom_iterations: dict[TestType | tuple[TestType, DeviceType], int] | None = None - custom_tolerances: dict[torch.dtype, Tolerance] | None = None + custom_iterations: dict[OptTestType | tuple[OptTestType, DeviceType] | tuple[OptTestType, DeviceType, torch.dtype], int] | None = None + spec: TestSpec = field(default_factory=TestSpec) only_dtypes: list[torch.dtype] | None = None def __post_init__(self): if self.reference_params is None: self.reference_params = self.optimi_params - if self.custom_tolerances is None: - self.custom_tolerances = {} - # reasonable defaults; override per-case as needed - self.custom_tolerances.setdefault(torch.float32, Tolerance()) - self.custom_tolerances.setdefault(torch.bfloat16, Tolerance(atol=1e-3, rtol=1e-2, 
max_error_rate=0.01)) - self.custom_tolerances.setdefault(torch.float16, Tolerance(atol=1e-4, rtol=1e-3, max_error_rate=0.01)) @property def optimizer_name(self) -> str: @@ -210,7 +239,7 @@ def default_variants(base: OptTest) -> list[OptTest]: skip_tests=list(base.skip_tests), any_precision=base.any_precision, custom_iterations=base.custom_iterations, - custom_tolerances=base.custom_tolerances, + spec=base.spec, only_dtypes=base.only_dtypes, fully_decoupled_reference=base.fully_decoupled_reference, ) diff --git a/test_new/opt_adan.py b/test_new/opt_adan.py index 39f5d78..b8e81f3 100644 --- a/test_new/opt_adan.py +++ b/test_new/opt_adan.py @@ -4,9 +4,10 @@ from typing import Any import optimi +import torch from tests import reference -from .config import BaseParams, OptTest +from .config import BaseParams, DeviceType, OptTest, OptTestType @dataclass @@ -31,6 +32,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=AdanParams(), reference_class=reference.Adan, reference_params=AdanParams(), + custom_iterations={(OptTestType.default, DeviceType.gpu, torch.bfloat16): 20}, ), OptTest( name="adan_weight_decay", @@ -38,6 +40,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=AdanParams(weight_decay=2e-2), reference_class=reference.Adan, reference_params=AdanParams(weight_decay=2e-2), + custom_iterations={(OptTestType.default, DeviceType.gpu, torch.bfloat16): 20}, ), OptTest( name="adan_adam_wd", @@ -45,6 +48,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=AdanParams(weight_decay=2e-2, adam_wd=True), reference_class=reference.Adan, reference_params=AdanParams(weight_decay=2e-2, weight_decouple=True), + custom_iterations={(OptTestType.default, DeviceType.gpu, torch.bfloat16): 20}, ), OptTest( name="adan_decoupled_lr", @@ -52,5 +56,6 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=AdanParams(weight_decay=2e-5, decouple_lr=True), reference_class=reference.Adan, reference_params=AdanParams(weight_decay=2e-2), + custom_iterations={(OptTestType.default, DeviceType.gpu, torch.bfloat16): 20}, ), ] diff --git a/test_new/opt_anyadam.py b/test_new/opt_anyadam.py index adeaaa2..354de7b 100644 --- a/test_new/opt_anyadam.py +++ b/test_new/opt_anyadam.py @@ -6,7 +6,7 @@ import torch from tests.reference import AnyPrecisionAdamW -from .config import BaseParams, OptTest, Tolerance +from .config import BaseParams, OptTest, OptTestType, Tolerance, with_updated_spec @dataclass @@ -36,7 +36,11 @@ def to_reference_kwargs(self, reference_class: type) -> dict: reference_params=AnyAdamParams(betas=(0.9, 0.99), use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, - custom_tolerances={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01, equal_nan=False)}, + spec=with_updated_spec( + spec=None, + test_type=OptTestType.default, + tolerances_override={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01)}, + ), ), OptTest( name="anyadam_kahan_wd", @@ -46,7 +50,11 @@ def to_reference_kwargs(self, reference_class: type) -> dict: reference_params=AnyAdamParams(betas=(0.9, 0.99), weight_decay=0.01, use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, - custom_tolerances={torch.bfloat16: Tolerance(rtol=5e-2, atol=1e-2, max_error_rate=0.01, equal_nan=False)}, + spec=with_updated_spec( + spec=None, + test_type=OptTestType.default, + tolerances_override={torch.bfloat16: Tolerance(rtol=5e-2, 
atol=1e-2, max_error_rate=0.01)}, + ), ), OptTest( name="anyadam_kahan_decoupled_lr", @@ -56,6 +64,10 @@ def to_reference_kwargs(self, reference_class: type) -> dict: reference_params=AnyAdamParams(betas=(0.9, 0.99), weight_decay=1e-2, use_kahan_summation=True), only_dtypes=[torch.bfloat16], any_precision=True, - custom_tolerances={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01, equal_nan=False)}, + spec=with_updated_spec( + spec=None, + test_type=OptTestType.default, + tolerances_override={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01)}, + ), ), ] diff --git a/test_new/opt_radam.py b/test_new/opt_radam.py index ed4a026..5e0d964 100644 --- a/test_new/opt_radam.py +++ b/test_new/opt_radam.py @@ -6,7 +6,7 @@ import optimi import torch -from .config import BaseParams, OptTest, Tolerance +from .config import BaseParams, OptTest, OptTestType, Tolerance, with_updated_spec @dataclass @@ -26,6 +26,6 @@ def __post_init__(self): optimi_params=RAdamParams(), reference_class=torch.optim.RAdam, reference_params=RAdamParams(), - custom_tolerances={torch.float32: Tolerance(max_error_rate=0.001)}, + spec=with_updated_spec(spec=None, test_type=OptTestType.default, tolerances_override={torch.float32: Tolerance(max_error_rate=0.001)}), test_decoupled_wd="decoupled_weight_decay" in inspect.signature(torch.optim.RAdam.__init__).parameters, ) diff --git a/test_new/opt_ranger.py b/test_new/opt_ranger.py index 0ffe25c..771cd86 100644 --- a/test_new/opt_ranger.py +++ b/test_new/opt_ranger.py @@ -5,7 +5,7 @@ import optimi from tests import reference -from .config import BaseParams, OptTest, TestType +from .config import BaseParams, OptTest, OptTestType @dataclass @@ -24,6 +24,6 @@ class RangerParams(BaseParams): reference_class=reference.Ranger, reference_params=RangerParams(), # Match legacy longer gradient-release coverage due to Lookahead cadence. 
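Aside (not part of the patch): custom_iterations overrides such as the Adan and Ranger entries in this commit are resolved most-specific-first by _get_iterations in the new runner.py further down in this patch. A simplified, string-keyed sketch of that precedence (the keys and counts below are stand-ins for the real enums):

def lookup_iterations(custom, test_type, device_type, dtype, default):
    # Simplified mirror of _get_iterations: the most specific key wins.
    if not custom:
        return default
    for key in ((test_type, device_type, dtype), (test_type, device_type), test_type):
        if key in custom:
            return custom[key]
    return default

# The Adan entries above resolve to 20 iterations only for bfloat16 on GPU:
assert lookup_iterations({("normal", "gpu", "bf16"): 20}, "normal", "gpu", "bf16", default=40) == 20
assert lookup_iterations({("normal", "gpu", "bf16"): 20}, "normal", "gpu", "fp32", default=40) == 40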
- custom_iterations={TestType.gradient_release: 160}, + custom_iterations={OptTestType.gradient_release: 160}, ) ] diff --git a/test_new/opt_sgd.py b/test_new/opt_sgd.py index 4528fff..5d0f055 100644 --- a/test_new/opt_sgd.py +++ b/test_new/opt_sgd.py @@ -7,7 +7,7 @@ import torch from tests import reference -from .config import BaseParams, OptTest, TestType +from .config import BaseParams, OptTest, OptTestType @dataclass @@ -32,7 +32,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=SGDParams(), reference_class=torch.optim.SGD, reference_params=SGDParams(), - skip_tests=[TestType.accumulation], + skip_tests=[OptTestType.accumulation], ), OptTest( name="sgd_momentum", @@ -54,7 +54,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=SGDParams(momentum=0.9, weight_decay=1e-2), reference_class=torch.optim.SGD, reference_params=SGDParams(momentum=0.9, weight_decay=1e-2), - skip_tests=[TestType.accumulation], + skip_tests=[OptTestType.accumulation], ), OptTest( name="sgd_decoupled_lr", diff --git a/test_new/runner.py b/test_new/runner.py new file mode 100644 index 0000000..9ede09c --- /dev/null +++ b/test_new/runner.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +import io +import random + +import torch +from optimi import prepare_for_gradient_release, remove_gradient_release +from torch import Tensor + +from .config import Backend, DeviceType, OptTest, OptTestType, Tolerance + + +def _device_type(device: torch.device) -> DeviceType: + return DeviceType.cpu if device.type == "cpu" else DeviceType.gpu + + +def _get_iterations( + opttest: OptTest, + test_type: OptTestType, + default: int, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +) -> int: + if not opttest.custom_iterations: + return default + if device is not None: + key = (test_type, _device_type(device)) + if dtype is not None: + dtype_key = (test_type, _device_type(device), dtype) + if dtype_key in opttest.custom_iterations: + return opttest.custom_iterations[dtype_key] + if key in opttest.custom_iterations: + return opttest.custom_iterations[key] + return opttest.custom_iterations.get(test_type, default) + + +def assert_most_approx_close( + a: torch.Tensor, + b: torch.Tensor, + rtol: float = 1e-3, + atol: float = 1e-3, + max_error_count: int = 0, + max_error_rate: float | None = None, + name: str = "", +) -> None: + """Assert that most values in two tensors are approximately close.""" + idx = torch.isclose(a.float(), b.float(), rtol=rtol, atol=atol) + error_count = (idx == 0).sum().item() + + if max_error_rate is not None: + if error_count > (a.numel()) * max_error_rate and error_count > max_error_count: + print(f"{name}Too many values not close: assert {error_count} < {(a.numel()) * max_error_rate}") + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) + elif error_count > max_error_count: + print(f"{name}Too many values not close: assert {error_count} < {max_error_count}") + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) + + +class MLP(torch.nn.Module): + def __init__(self, input_size: int, hidden_size: int, device: torch.device, dtype: torch.dtype): + super().__init__() + self.norm = torch.nn.LayerNorm(input_size, device=device, dtype=dtype) + self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=False, device=device, dtype=dtype) + self.act = torch.nn.Mish() + self.fc2 = torch.nn.Linear(hidden_size, 1, bias=False, device=device, dtype=dtype) + + def forward(self, x: 
Tensor) -> Tensor: + x = self.norm(x) + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +def run_test( + opttest: OptTest, + device: torch.device, + dtype: torch.dtype, + backend: Backend, + test_type: OptTestType, + dims: tuple[int, int] | None = None, +) -> None: + if test_type == OptTestType.default: + default_spec = opttest.spec.default + default_iters = default_spec.iterations_cpu if device.type == "cpu" else default_spec.iterations_gpu + iterations = _get_iterations(opttest, test_type, default_iters, device=device, dtype=dtype) + tolerance = default_spec.tolerance[dtype] + + if dims is None: + dim1, dim2 = (64, 128) if device.type == "cpu" else (256, 512) + else: + dim1, dim2 = dims + + batch_size = default_spec.batch_cpu if device.type == "cpu" else default_spec.batch_gpu + max_error_count = default_spec.max_error_cpu if device.type == "cpu" else default_spec.max_error_gpu + max_error_rate = tolerance.max_error_rate + + elif test_type == OptTestType.gradient_release: + gradient_spec = opttest.spec.gradient_release + iterations = _get_iterations(opttest, test_type, gradient_spec.iterations, device=device, dtype=dtype) + tolerance = gradient_spec.tolerance[dtype] + + dim1, dim2 = dims if dims is not None else (128, 256) + batch_size = gradient_spec.batch + max_error_count = gradient_spec.max_error_count + max_error_rate = tolerance.max_error_rate + + elif test_type == OptTestType.accumulation: + accumulation_spec = opttest.spec.accumulation + iterations = _get_iterations(opttest, test_type, accumulation_spec.iterations, device=device, dtype=dtype) + tolerance = accumulation_spec.tolerance[dtype] + dim1, dim2 = dims if dims is not None else (128, 256) + batch_size = accumulation_spec.batch + max_error_count = 0 + max_error_rate = accumulation_spec.max_error_rate + else: + raise ValueError(f"Unknown test type: {test_type}") + + m1 = MLP(dim1, dim2, device=device, dtype=dtype) + m2 = MLP(dim1, dim2, device=device, dtype=dtype) + m2.load_state_dict(m1.state_dict()) + + if test_type == OptTestType.gradient_release: + m3 = MLP(dim1, dim2, device=device, dtype=dtype) + m3.load_state_dict(m1.state_dict()) + else: + m3 = None + + if test_type == OptTestType.default and not opttest.any_precision and dtype != torch.float32: + for p in m1.parameters(): + p.data = p.data.float() + + reference_kwargs = opttest.to_reference_kwargs(backend) + optimi_kwargs = opttest.to_optimi_kwargs(backend) + reference_class = opttest.reference_class + + reference_optimizer = None + optimi_optimizer = None + torch_optimizers: dict[torch.nn.Parameter, torch.optim.Optimizer] | None = None + pytorch_hooks: list[torch.utils.hooks.RemovableHandle] = [] + + if test_type == OptTestType.default: + reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) + optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) + buffer = io.BytesIO() + elif test_type == OptTestType.gradient_release: + reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) + + def optimizer_hook(parameter) -> None: + assert torch_optimizers is not None + torch_optimizers[parameter].step() + torch_optimizers[parameter].zero_grad() + + torch_optimizers = {p: reference_class([p], **reference_kwargs) for p in m2.parameters()} + for p in m2.parameters(): + pytorch_hooks.append(p.register_post_accumulate_grad_hook(optimizer_hook)) + + optimi_kwargs["gradient_release"] = True + optimi_optimizer = opttest.optimi_class(m3.parameters(), **optimi_kwargs) + prepare_for_gradient_release(m3, 
optimi_optimizer) + else: + reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) + optimi_kwargs["gradient_release"] = True + optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) + prepare_for_gradient_release(m2, optimi_optimizer) + gradient_accumulation_steps = accumulation_spec.gradient_accumulation_steps + + for i in range(iterations): + input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) + if test_type == OptTestType.default: + input2 = input1.detach().clone() + else: + input2 = input1.clone() + target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) + if test_type == OptTestType.default: + target2 = target1.detach().clone() + else: + target2 = target1.clone() + + if test_type == OptTestType.gradient_release: + input3 = input1.clone() + target3 = target1.clone() + else: + input3 = None + target3 = None + + if test_type == OptTestType.default and not opttest.any_precision and dtype != torch.float32: + input1 = input1.float() + target1 = target1.float() + + if test_type == OptTestType.accumulation: + optimi_optimizer.optimizer_accumulation = (i + 1) % gradient_accumulation_steps != 0 + + output1 = m1(input1) + output2 = m2(input2) + output3 = m3(input3) if m3 is not None else None + + loss1 = torch.nn.functional.mse_loss(output1, target1) + loss2 = torch.nn.functional.mse_loss(output2, target2) + loss3 = torch.nn.functional.mse_loss(output3, target3) if output3 is not None else None + + loss1.backward() + loss2.backward() + if loss3 is not None: + loss3.backward() + + if test_type == OptTestType.default: + reference_optimizer.step() + optimi_optimizer.step() + reference_optimizer.zero_grad() + optimi_optimizer.zero_grad() + elif test_type == OptTestType.gradient_release: + reference_optimizer.step() + reference_optimizer.zero_grad() + elif not optimi_optimizer.optimizer_accumulation: + reference_optimizer.step() + reference_optimizer.zero_grad() + + if test_type in (OptTestType.gradient_release, OptTestType.accumulation): + if random.random() < 0.5: + optimi_optimizer.step() + optimi_optimizer.zero_grad() + + if test_type == OptTestType.default: + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="fc1: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="fc2: ", + ) + + if i % max(1, iterations // 10) == 0 and i > 0: + torch.save(optimi_optimizer.state_dict(), buffer) + buffer.seek(0) + ckpt = torch.load(buffer, weights_only=True) + optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) + optimi_optimizer.load_state_dict(ckpt) + buffer.seek(0) + buffer.truncate(0) + + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="fc1 after load: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + atol=tolerance.atol, + rtol=tolerance.rtol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="fc2 after load: ", + ) + elif test_type == OptTestType.gradient_release: + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="PyTorch-PyTorch: ", + ) + 
assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="PyTorch-PyTorch: ", + ) + assert_most_approx_close( + m1.fc1.weight, + m3.fc1.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="PyTorch-Optimi: ", + ) + assert_most_approx_close( + m1.fc2.weight, + m3.fc2.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + name="PyTorch-Optimi: ", + ) + + if test_type == OptTestType.accumulation: + assert_most_approx_close( + m1.fc1.weight, + m2.fc1.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + ) + assert_most_approx_close( + m1.fc2.weight, + m2.fc2.weight, + rtol=tolerance.rtol, + atol=tolerance.atol, + max_error_count=max_error_count, + max_error_rate=max_error_rate, + ) + + for h in pytorch_hooks: + h.remove() + if test_type == OptTestType.gradient_release: + remove_gradient_release(m3) + elif test_type == OptTestType.accumulation: + remove_gradient_release(m2) diff --git a/test_new/runners.py b/test_new/runners.py deleted file mode 100644 index c2dde32..0000000 --- a/test_new/runners.py +++ /dev/null @@ -1,376 +0,0 @@ -from __future__ import annotations - -import io -import random - -import torch -from optimi import prepare_for_gradient_release, remove_gradient_release -from torch import Tensor - -from .config import DEFAULTS, Backend, DeviceType, OptTest, TestType, Tolerance - - -def _device_type(device: torch.device) -> DeviceType: - return DeviceType.cpu if device.type == "cpu" else DeviceType.gpu - - -def _get_iterations( - opttest: OptTest, - test_type: TestType, - default: int, - device: torch.device | None = None, -) -> int: - if not opttest.custom_iterations: - return default - if device is not None: - key = (test_type, _device_type(device)) - if key in opttest.custom_iterations: - return opttest.custom_iterations[key] - return opttest.custom_iterations.get(test_type, default) - - -def assert_most_approx_close( - a: torch.Tensor, - b: torch.Tensor, - rtol: float = 1e-3, - atol: float = 1e-3, - max_error_count: int = 0, - max_error_rate: float | None = None, - name: str = "", -) -> None: - """Assert that most values in two tensors are approximately close. - - Allows for a small number of errors based on max_error_count and max_error_rate. 
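Aside (not part of the patch): a small, self-contained illustration of the mismatch budget this helper enforces in both the old runners.py and the new runner.py; the tensors and tolerances below are made up.

import torch

a = torch.zeros(1000)
b = torch.zeros(1000)
b[:3] = 1.0  # 3 elements out of tolerance

mismatches = (~torch.isclose(a, b, rtol=1e-5, atol=1e-8)).sum().item()
assert mismatches == 3
# With max_error_count=5 (the GPU default for the normal test) the comparison is accepted;
# with max_error_count=2 (the CPU default) it falls through to torch.testing.assert_close
# and raises. When max_error_rate is also set, both the rate and the count must be
# exceeded before the helper escalates.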
- """ - idx = torch.isclose(a.float(), b.float(), rtol=rtol, atol=atol) - error_count = (idx == 0).sum().item() - - if max_error_rate is not None: - if error_count > (a.numel()) * max_error_rate and error_count > max_error_count: - print(f"{name}Too many values not close: assert {error_count} < {(a.numel()) * max_error_rate}") - torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) - elif error_count > max_error_count: - print(f"{name}Too many values not close: assert {error_count} < {max_error_count}") - torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) - - -class MLP(torch.nn.Module): - def __init__(self, input_size: int, hidden_size: int, device: torch.device, dtype: torch.dtype): - super().__init__() - self.norm = torch.nn.LayerNorm(input_size, device=device, dtype=dtype) - self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=False, device=device, dtype=dtype) - self.act = torch.nn.Mish() - self.fc2 = torch.nn.Linear(hidden_size, 1, bias=False, device=device, dtype=dtype) - - def forward(self, x: Tensor) -> Tensor: - x = self.norm(x) - x = self.fc1(x) - x = self.act(x) - x = self.fc2(x) - return x - - -def run_correctness( - opttest: OptTest, - device: torch.device, - dtype: torch.dtype, - backend: Backend, - dims: tuple[int, int] | None = None, -) -> None: - # Iterations and tolerance - default_iters = DEFAULTS.correctness.cpu_iterations if device.type == "cpu" else DEFAULTS.correctness.gpu_iterations - iterations = _get_iterations(opttest, TestType.correctness, default_iters, device=device) - # Special-opttest: Adan bf16 on GPU - if device.type != "cpu" and dtype == torch.bfloat16 and opttest.optimizer_name == "adan": - iterations = DEFAULTS.correctness.adan_bf16_gpu_iterations - tolerance = opttest.custom_tolerances[dtype] - # Dims, batch, errors - if dims is not None: - dim1, dim2 = dims - elif device.type == "cpu": - dim1, dim2 = DEFAULTS.correctness.cpu_dims - else: - dim1, dim2 = DEFAULTS.correctness.gpu_dims - batch_size = DEFAULTS.correctness.cpu_batch_size if device.type == "cpu" else DEFAULTS.correctness.gpu_batch_size - max_error_count = DEFAULTS.correctness.cpu_max_error_count if device.type == "cpu" else DEFAULTS.correctness.gpu_max_error_count - - # Create models - m1 = MLP(dim1, dim2, device=device, dtype=dtype) - m2 = MLP(dim1, dim2, device=device, dtype=dtype) - m2.load_state_dict(m1.state_dict()) - - # Convert parameters to float for non-any_precision - if not opttest.any_precision and dtype != torch.float32: - for p in m1.parameters(): - p.data = p.data.float() - - # Optimizers - reference_class = opttest.reference_class - reference_kwargs = opttest.to_reference_kwargs(backend) - optimi_kwargs = opttest.to_optimi_kwargs(backend) - reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) - optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) - - buffer = io.BytesIO() - - for i in range(iterations): - input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) - input2 = input1.detach().clone() - target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) - target2 = target1.detach().clone() - - if not opttest.any_precision and dtype != torch.float32: - input1 = input1.float() - target1 = target1.float() - - output1 = m1(input1) - output2 = m2(input2) - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - loss1.backward() - loss2.backward() - - reference_optimizer.step() - optimi_optimizer.step() - 
reference_optimizer.zero_grad() - optimi_optimizer.zero_grad() - - assert_most_approx_close( - m1.fc1.weight, - m2.fc1.weight, - atol=tolerance.atol, - rtol=tolerance.rtol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="fc1: ", - ) - assert_most_approx_close( - m1.fc2.weight, - m2.fc2.weight, - atol=tolerance.atol, - rtol=tolerance.rtol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="fc2: ", - ) - - # state_dict save/load periodically - if i % max(1, iterations // 10) == 0 and i > 0: - torch.save(optimi_optimizer.state_dict(), buffer) - buffer.seek(0) - ckpt = torch.load(buffer, weights_only=True) - optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) - optimi_optimizer.load_state_dict(ckpt) - buffer.seek(0) - buffer.truncate(0) - - assert_most_approx_close( - m1.fc1.weight, - m2.fc1.weight, - atol=tolerance.atol, - rtol=tolerance.rtol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="fc1 after load: ", - ) - assert_most_approx_close( - m1.fc2.weight, - m2.fc2.weight, - atol=tolerance.atol, - rtol=tolerance.rtol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="fc2 after load: ", - ) - - -def run_gradient_release( - opttest: OptTest, - device: torch.device, - dtype: torch.dtype, - backend: Backend, - dims: tuple[int, int] | None = None, -) -> None: - def optimizer_hook(parameter) -> None: - torch_optimizers[parameter].step() - torch_optimizers[parameter].zero_grad() - - # Iterations - iterations = _get_iterations(opttest, TestType.gradient_release, DEFAULTS.gradient_release.iterations, device=device) - - # Tolerances: merge baseline with per-opttest - tol = opttest.custom_tolerances[dtype] - baseline = DEFAULTS.gradient_release.baseline_tolerance.get(dtype, tol) - tolerance = Tolerance( - rtol=max(tol.rtol, baseline.rtol), - atol=max(tol.atol, baseline.atol), - max_error_rate=max(tol.max_error_rate, baseline.max_error_rate), - equal_nan=tol.equal_nan, - ) - - max_error_count = DEFAULTS.gradient_release.max_error_count - - # Dims and batch size - if dims is not None: - dim1, dim2 = dims - else: - dim1, dim2 = DEFAULTS.gradient_release.dims - - batch_size = DEFAULTS.gradient_release.batch_size - - m1 = MLP(dim1, dim2, device=device, dtype=dtype) # regular - m2 = MLP(dim1, dim2, device=device, dtype=dtype) # PyTorch hooks - m3 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi gradient release - m2.load_state_dict(m1.state_dict()) - m3.load_state_dict(m1.state_dict()) - - reference_class = opttest.reference_class - reference_kwargs = opttest.to_reference_kwargs(backend) - optimi_kwargs = opttest.to_optimi_kwargs(backend) - - regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) - torch_optimizers = {p: reference_class([p], **reference_kwargs) for p in m2.parameters()} - pytorch_hooks = [] - for p in m2.parameters(): - pytorch_hooks.append(p.register_post_accumulate_grad_hook(optimizer_hook)) - - optimi_kwargs["gradient_release"] = True - optimi_optimizer = opttest.optimi_class(m3.parameters(), **optimi_kwargs) - prepare_for_gradient_release(m3, optimi_optimizer) - - for _ in range(iterations): - input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) - input2 = input1.clone() - input3 = input1.clone() - target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) - target2 = target1.clone() - target3 = target1.clone() - - output1 = m1(input1) - output2 = m2(input2) - output3 = 
m3(input3) - - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - loss3 = torch.nn.functional.mse_loss(output3, target3) - - loss1.backward() - loss2.backward() - loss3.backward() - - regular_optimizer.step() - regular_optimizer.zero_grad() - - # Random step/zero_grad to simulate using optimi's accumulation in a framework like Composer - if random.random() < 0.5: - optimi_optimizer.step() - optimi_optimizer.zero_grad() - - assert_most_approx_close( - m1.fc1.weight, - m2.fc1.weight, - rtol=tolerance.rtol, - atol=tolerance.atol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="PyTorch-PyTorch: ", - ) - assert_most_approx_close( - m1.fc2.weight, - m2.fc2.weight, - rtol=tolerance.rtol, - atol=tolerance.atol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="PyTorch-PyTorch: ", - ) - assert_most_approx_close( - m1.fc1.weight, - m3.fc1.weight, - rtol=tolerance.rtol, - atol=tolerance.atol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="PyTorch-Optimi: ", - ) - assert_most_approx_close( - m1.fc2.weight, - m3.fc2.weight, - rtol=tolerance.rtol, - atol=tolerance.atol, - max_error_count=max_error_count, - max_error_rate=tolerance.max_error_rate, - name="PyTorch-Optimi: ", - ) - - for h in pytorch_hooks: - h.remove() - remove_gradient_release(m3) - - -def run_accumulation( - opttest: OptTest, - device: torch.device, - dtype: torch.dtype, - backend: Backend, - dims: tuple[int, int] | None = None, -) -> None: - # Iterations - iterations = _get_iterations(opttest, TestType.accumulation, DEFAULTS.accumulation.iterations, device=device) - - # Dims and batch size - if dims is not None: - dim1, dim2 = dims - else: - dim1, dim2 = DEFAULTS.accumulation.dims - - batch_size = DEFAULTS.accumulation.batch_size - - # Tolerance and error rate - tolerance = DEFAULTS.accumulation.tolerance - max_error_rate = DEFAULTS.accumulation.max_error_rate - - m1 = MLP(dim1, dim2, device=device, dtype=dtype) # Regular optimizer - m2 = MLP(dim1, dim2, device=device, dtype=dtype) # Optimi accumulation - m2.load_state_dict(m1.state_dict()) - - reference_class = opttest.reference_class - reference_kwargs = opttest.to_reference_kwargs(backend) - optimi_kwargs = opttest.to_optimi_kwargs(backend) - - regular_optimizer = reference_class(m1.parameters(), **reference_kwargs) - optimi_kwargs["gradient_release"] = True - optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) - prepare_for_gradient_release(m2, optimi_optimizer) - - gradient_accumulation_steps = DEFAULTS.accumulation.gradient_accumulation_steps - - for i in range(iterations): - input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) - input2 = input1.clone() - target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) - target2 = target1.clone() - - optimi_optimizer.optimizer_accumulation = (i + 1) % gradient_accumulation_steps != 0 - - output1 = m1(input1) - output2 = m2(input2) - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - - loss1.backward() - loss2.backward() - - if not optimi_optimizer.optimizer_accumulation: - regular_optimizer.step() - regular_optimizer.zero_grad() - - # Random step/zero_grad to simulate using optimi's accumulation in a framework like Composer - if random.random() < 0.5: - optimi_optimizer.step() - optimi_optimizer.zero_grad() - - assert_most_approx_close(m1.fc1.weight, 
m2.fc1.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) - assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=tolerance.rtol, atol=tolerance.atol, max_error_rate=max_error_rate) - - remove_gradient_release(m2) diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py index 9f08d9f..957130a 100644 --- a/test_new/test_optimizers_unified.py +++ b/test_new/test_optimizers_unified.py @@ -3,8 +3,8 @@ from _pytest.mark.structures import ParameterSet from optimi.utils import MIN_TORCH_2_6 -from .config import Backend, DeviceType, OptTest, TestType, discover_tests -from .runners import run_accumulation, run_correctness, run_gradient_release +from .config import Backend, DeviceType, OptTest, OptTestType, discover_tests +from .runner import run_test DEVICE_PARAMS = [ pytest.param(DeviceType.cpu, marks=pytest.mark.cpu, id=DeviceType.cpu.value), @@ -41,7 +41,7 @@ ] -def _should_skip(test_type: TestType, opttest: OptTest, device_type: DeviceType, dtype: torch.dtype, backend: Backend) -> bool: +def _should_skip(test_type: OptTestType, opttest: OptTest, device_type: DeviceType, dtype: torch.dtype, backend: Backend) -> bool: # Explicit per-opttest skip if test_type in set(opttest.skip_tests): return True @@ -75,7 +75,11 @@ def _should_skip(test_type: TestType, opttest: OptTest, device_type: DeviceType, return True # Gradient release and accumulation are GPU-only tests - if test_type in (TestType.gradient_release, TestType.accumulation) and device_type == DeviceType.cpu: + if test_type in (OptTestType.gradient_release, OptTestType.accumulation) and device_type == DeviceType.cpu: + return True + + # Gradient release / accumulation are incompatible with foreach + if test_type in (OptTestType.gradient_release, OptTestType.accumulation) and backend == Backend.foreach: return True # bfloat16 is not supported on MPS @@ -102,8 +106,8 @@ def _param_id(param: ParameterSet) -> str: return param.id or str(param.values[0]) -def _build_params(test_type: TestType) -> list[ParameterSet]: - if test_type == TestType.correctness: +def _build_params(test_type: OptTestType) -> list[ParameterSet]: + if test_type == OptTestType.default: device_params = DEVICE_PARAMS dtype_params = DTYPE_PARAMS dims_params = FULL_DIMS @@ -118,7 +122,7 @@ def _build_params(test_type: TestType) -> list[ParameterSet]: for dtype_param in dtype_params: for backend_param in BACKEND_PARAMS: for dims_param in dims_params: - if test_type == TestType.correctness and _param_value(dims_param)[0] != _param_value(device_param): + if test_type == OptTestType.default and _param_value(dims_param)[0] != _param_value(device_param): continue if _should_skip( test_type, @@ -151,18 +155,18 @@ def _build_params(test_type: TestType) -> list[ParameterSet]: return params -@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims_spec", _build_params(TestType.correctness)) -def test_correctness(opttest, device_type, dtype, backend, dims_spec, gpu_device): +@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims_spec", _build_params(OptTestType.default)) +def test_default(opttest, device_type, dtype, backend, dims_spec, gpu_device): _, dims = dims_spec device = torch.device(gpu_device if device_type == DeviceType.gpu else "cpu") - run_correctness(opttest, device, dtype, backend, dims=dims) + run_test(opttest, device, dtype, backend, OptTestType.default, dims=dims) -@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(TestType.gradient_release)) 
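Aside (not part of the patch): with the three runners collapsed into run_test, a single configuration can also be exercised outside pytest. A rough sketch, assuming the suite is importable as test_new and that discover_tests() yields the OptTest definitions behind OPTIMIZERS:

import torch

from test_new.config import Backend, OptTestType, discover_tests
from test_new.runner import run_test

opttest = discover_tests()[0]  # assumption: returns the OptTest definitions used for OPTIMIZERS
run_test(
    opttest,
    device=torch.device("cuda"),
    dtype=torch.float32,
    backend=Backend.torch,
    test_type=OptTestType.gradient_release,
    dims=(128, 256),  # the runner's default gradient-release shape
)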
+@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(OptTestType.gradient_release)) def test_gradient_release(opttest, device_type, dtype, backend, dims, gpu_device): - run_gradient_release(opttest, torch.device(gpu_device), dtype, backend, dims=dims) + run_test(opttest, torch.device(gpu_device), dtype, backend, OptTestType.gradient_release, dims=dims) -@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(TestType.accumulation)) +@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(OptTestType.accumulation)) def test_accumulation(opttest, device_type, dtype, backend, dims, gpu_device): - run_accumulation(opttest, torch.device(gpu_device), dtype, backend, dims=dims) + run_test(opttest, torch.device(gpu_device), dtype, backend, OptTestType.accumulation, dims=dims) From 03303de042f41b48a7c3de061cb3567313a8d92a Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 26 Dec 2025 11:33:26 -0600 Subject: [PATCH 11/16] fixes to last commit --- test_new/config.py | 28 ++++++++++----------- test_new/opt_adan.py | 8 +++--- test_new/opt_anyadam.py | 6 ++--- test_new/opt_radam.py | 2 +- test_new/runner.py | 38 ++++++++++++++--------------- test_new/test_optimizers_unified.py | 13 +++++----- 6 files changed, 46 insertions(+), 49 deletions(-) diff --git a/test_new/config.py b/test_new/config.py index 116531f..06dbc21 100644 --- a/test_new/config.py +++ b/test_new/config.py @@ -14,7 +14,7 @@ class OptTestType(Enum): - default = "default" + normal = "normal" gradient_release = "gradient_release" accumulation = "accumulation" @@ -39,7 +39,7 @@ class Tolerance: @dataclass() -class CorrectnessSpec: +class NormalSpec: iterations_cpu: int = 20 iterations_gpu: int = 40 batch_cpu: int = 1 @@ -89,18 +89,18 @@ class AccumulationSpec: @dataclass() class TestSpec: - default: CorrectnessSpec = CorrectnessSpec() - gradient_release: GradientReleaseSpec = GradientReleaseSpec() - accumulation: AccumulationSpec = AccumulationSpec() + normal: NormalSpec = field(default_factory=NormalSpec) + gradient_release: GradientReleaseSpec = field(default_factory=GradientReleaseSpec) + accumulation: AccumulationSpec = field(default_factory=AccumulationSpec) def with_updated_spec( - spec: TestSpec | CorrectnessSpec | GradientReleaseSpec | AccumulationSpec | None, + spec: TestSpec | NormalSpec | GradientReleaseSpec | AccumulationSpec | None, test_type: OptTestType | None = None, tolerances_override: dict[torch.dtype, Tolerance] | None = None, ) -> TestSpec: - if isinstance(spec, (CorrectnessSpec, GradientReleaseSpec, AccumulationSpec)): - if isinstance(spec, CorrectnessSpec): + if isinstance(spec, (NormalSpec, GradientReleaseSpec, AccumulationSpec)): + if isinstance(spec, NormalSpec): base = TestSpec(default=spec) elif isinstance(spec, GradientReleaseSpec): base = TestSpec(gradient_release=spec) @@ -115,15 +115,13 @@ def with_updated_spec( if test_type is None: return base - if test_type == OptTestType.default: - merged = {**base.default.tolerance, **tolerances_override} - return replace(base, default=replace(base.default, tolerance=merged)) + if test_type == OptTestType.normal: + merged = {**base.normal.tolerance, **tolerances_override} + return replace(base, normal=replace(base.normal, tolerance=merged)) if test_type == OptTestType.gradient_release: - merged = {**base.gradient_release.baseline_tolerance, **tolerances_override} - return replace(base, gradient_release=replace(base.gradient_release, baseline_tolerance=merged)) + return 
replace(base, gradient_release=replace(base.gradient_release, **tolerances_override)) if test_type == OptTestType.accumulation: - merged = {**base.accumulation.tolerance, **tolerances_override} - return replace(base, accumulation=replace(base.accumulation, tolerance=merged)) + return replace(base, accumulation=replace(base.accumulation, **tolerances_override)) raise ValueError(f"Unknown test type: {test_type}") diff --git a/test_new/opt_adan.py b/test_new/opt_adan.py index b8e81f3..524d5c7 100644 --- a/test_new/opt_adan.py +++ b/test_new/opt_adan.py @@ -32,7 +32,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=AdanParams(), reference_class=reference.Adan, reference_params=AdanParams(), - custom_iterations={(OptTestType.default, DeviceType.gpu, torch.bfloat16): 20}, + custom_iterations={(OptTestType.normal, DeviceType.gpu, torch.bfloat16): 20}, ), OptTest( name="adan_weight_decay", @@ -40,7 +40,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=AdanParams(weight_decay=2e-2), reference_class=reference.Adan, reference_params=AdanParams(weight_decay=2e-2), - custom_iterations={(OptTestType.default, DeviceType.gpu, torch.bfloat16): 20}, + custom_iterations={(OptTestType.normal, DeviceType.gpu, torch.bfloat16): 20}, ), OptTest( name="adan_adam_wd", @@ -48,7 +48,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=AdanParams(weight_decay=2e-2, adam_wd=True), reference_class=reference.Adan, reference_params=AdanParams(weight_decay=2e-2, weight_decouple=True), - custom_iterations={(OptTestType.default, DeviceType.gpu, torch.bfloat16): 20}, + custom_iterations={(OptTestType.normal, DeviceType.gpu, torch.bfloat16): 20}, ), OptTest( name="adan_decoupled_lr", @@ -56,6 +56,6 @@ def to_reference_kwargs(self, reference_class: type) -> dict[str, Any]: optimi_params=AdanParams(weight_decay=2e-5, decouple_lr=True), reference_class=reference.Adan, reference_params=AdanParams(weight_decay=2e-2), - custom_iterations={(OptTestType.default, DeviceType.gpu, torch.bfloat16): 20}, + custom_iterations={(OptTestType.normal, DeviceType.gpu, torch.bfloat16): 20}, ), ] diff --git a/test_new/opt_anyadam.py b/test_new/opt_anyadam.py index 354de7b..17d87ab 100644 --- a/test_new/opt_anyadam.py +++ b/test_new/opt_anyadam.py @@ -38,7 +38,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict: any_precision=True, spec=with_updated_spec( spec=None, - test_type=OptTestType.default, + test_type=OptTestType.normal, tolerances_override={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01)}, ), ), @@ -52,7 +52,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict: any_precision=True, spec=with_updated_spec( spec=None, - test_type=OptTestType.default, + test_type=OptTestType.normal, tolerances_override={torch.bfloat16: Tolerance(rtol=5e-2, atol=1e-2, max_error_rate=0.01)}, ), ), @@ -66,7 +66,7 @@ def to_reference_kwargs(self, reference_class: type) -> dict: any_precision=True, spec=with_updated_spec( spec=None, - test_type=OptTestType.default, + test_type=OptTestType.normal, tolerances_override={torch.bfloat16: Tolerance(rtol=2e-2, atol=2e-3, max_error_rate=0.01)}, ), ), diff --git a/test_new/opt_radam.py b/test_new/opt_radam.py index 5e0d964..08aadf8 100644 --- a/test_new/opt_radam.py +++ b/test_new/opt_radam.py @@ -26,6 +26,6 @@ def __post_init__(self): optimi_params=RAdamParams(), reference_class=torch.optim.RAdam, reference_params=RAdamParams(), - 
spec=with_updated_spec(spec=None, test_type=OptTestType.default, tolerances_override={torch.float32: Tolerance(max_error_rate=0.001)}), + spec=with_updated_spec(spec=None, test_type=OptTestType.normal, tolerances_override={torch.float32: Tolerance(max_error_rate=0.001)}), test_decoupled_wd="decoupled_weight_decay" in inspect.signature(torch.optim.RAdam.__init__).parameters, ) diff --git a/test_new/runner.py b/test_new/runner.py index 9ede09c..f0cecba 100644 --- a/test_new/runner.py +++ b/test_new/runner.py @@ -7,7 +7,7 @@ from optimi import prepare_for_gradient_release, remove_gradient_release from torch import Tensor -from .config import Backend, DeviceType, OptTest, OptTestType, Tolerance +from .config import Backend, DeviceType, OptTest, OptTestType def _device_type(device: torch.device) -> DeviceType: @@ -49,11 +49,11 @@ def assert_most_approx_close( if max_error_rate is not None: if error_count > (a.numel()) * max_error_rate and error_count > max_error_count: - print(f"{name}Too many values not close: assert {error_count} < {(a.numel()) * max_error_rate}") - torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) + msg = f"{name}Too many values not close: assert {error_count} < {(a.numel()) * max_error_rate}" + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol, msg=msg) elif error_count > max_error_count: - print(f"{name}Too many values not close: assert {error_count} < {max_error_count}") - torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) + msg = f"{name}Too many values not close: assert {error_count} < {max_error_count}" + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol, msg=msg) class MLP(torch.nn.Module): @@ -80,19 +80,19 @@ def run_test( test_type: OptTestType, dims: tuple[int, int] | None = None, ) -> None: - if test_type == OptTestType.default: - default_spec = opttest.spec.default - default_iters = default_spec.iterations_cpu if device.type == "cpu" else default_spec.iterations_gpu - iterations = _get_iterations(opttest, test_type, default_iters, device=device, dtype=dtype) - tolerance = default_spec.tolerance[dtype] + if test_type == OptTestType.normal: + normal_spec = opttest.spec.normal + normal_iters = normal_spec.iterations_cpu if device.type == "cpu" else normal_spec.iterations_gpu + iterations = _get_iterations(opttest, test_type, normal_iters, device=device, dtype=dtype) + tolerance = normal_spec.tolerance[dtype] if dims is None: dim1, dim2 = (64, 128) if device.type == "cpu" else (256, 512) else: dim1, dim2 = dims - batch_size = default_spec.batch_cpu if device.type == "cpu" else default_spec.batch_gpu - max_error_count = default_spec.max_error_cpu if device.type == "cpu" else default_spec.max_error_gpu + batch_size = normal_spec.batch_cpu if device.type == "cpu" else normal_spec.batch_gpu + max_error_count = normal_spec.max_error_cpu if device.type == "cpu" else normal_spec.max_error_gpu max_error_rate = tolerance.max_error_rate elif test_type == OptTestType.gradient_release: @@ -126,7 +126,7 @@ def run_test( else: m3 = None - if test_type == OptTestType.default and not opttest.any_precision and dtype != torch.float32: + if test_type == OptTestType.normal and not opttest.any_precision and dtype != torch.float32: for p in m1.parameters(): p.data = p.data.float() @@ -139,7 +139,7 @@ def run_test( torch_optimizers: dict[torch.nn.Parameter, torch.optim.Optimizer] | None = None pytorch_hooks: list[torch.utils.hooks.RemovableHandle] = [] - if test_type == OptTestType.default: + if test_type == 
OptTestType.normal: reference_optimizer = reference_class(m1.parameters(), **reference_kwargs) optimi_optimizer = opttest.optimi_class(m2.parameters(), **optimi_kwargs) buffer = io.BytesIO() @@ -167,12 +167,12 @@ def optimizer_hook(parameter) -> None: for i in range(iterations): input1 = torch.randn(batch_size, dim1, device=device, dtype=dtype) - if test_type == OptTestType.default: + if test_type == OptTestType.normal: input2 = input1.detach().clone() else: input2 = input1.clone() target1 = torch.randn(batch_size, 1, device=device, dtype=dtype) - if test_type == OptTestType.default: + if test_type == OptTestType.normal: target2 = target1.detach().clone() else: target2 = target1.clone() @@ -184,7 +184,7 @@ def optimizer_hook(parameter) -> None: input3 = None target3 = None - if test_type == OptTestType.default and not opttest.any_precision and dtype != torch.float32: + if test_type == OptTestType.normal and not opttest.any_precision and dtype != torch.float32: input1 = input1.float() target1 = target1.float() @@ -204,7 +204,7 @@ def optimizer_hook(parameter) -> None: if loss3 is not None: loss3.backward() - if test_type == OptTestType.default: + if test_type == OptTestType.normal: reference_optimizer.step() optimi_optimizer.step() reference_optimizer.zero_grad() @@ -221,7 +221,7 @@ def optimizer_hook(parameter) -> None: optimi_optimizer.step() optimi_optimizer.zero_grad() - if test_type == OptTestType.default: + if test_type == OptTestType.normal: assert_most_approx_close( m1.fc1.weight, m2.fc1.weight, diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py index 957130a..17230ec 100644 --- a/test_new/test_optimizers_unified.py +++ b/test_new/test_optimizers_unified.py @@ -23,8 +23,7 @@ # Attach per-optimizer marks so users can -m adam, -m sgd, etc. 
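As context for the collection code that follows: getattr(pytest.mark, name) builds the same mark object as writing @pytest.mark.<name> directly, which is what lets each discovered OptTest carry its optimizer's mark. A minimal sketch, with the "adam" name purely illustrative (real names come from OptTest.optimizer_name):

    import pytest

    name = "adam"  # illustrative; the framework derives this from OptTest.optimizer_name
    mark = getattr(pytest.mark, name)  # same mark as the @pytest.mark.adam decorator
    param = pytest.param(object(), id=name, marks=mark)  # object() stands in for an OptTest
    assert mark.name == name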
OPTIMIZERS = [pytest.param(c, id=c.name, marks=getattr(pytest.mark, c.optimizer_name)) for c in discover_tests()] -# Dimension parameter spaces (match legacy tests) -# Correctness dims: CPU -> (64,64), (64,128); GPU -> (256,256), (256,512), (256,1024), (256,2048) +# Full dimensions: CPU -> (64,64), (64,128); GPU -> (256,256), (256,512), (256,1024), (256,2048) FULL_DIMS = [ pytest.param((DeviceType.cpu, (64, 64)), id="cpu-64x64"), pytest.param((DeviceType.cpu, (64, 128)), id="cpu-64x128"), @@ -107,7 +106,7 @@ def _param_id(param: ParameterSet) -> str: def _build_params(test_type: OptTestType) -> list[ParameterSet]: - if test_type == OptTestType.default: + if test_type == OptTestType.normal: device_params = DEVICE_PARAMS dtype_params = DTYPE_PARAMS dims_params = FULL_DIMS @@ -122,7 +121,7 @@ def _build_params(test_type: OptTestType) -> list[ParameterSet]: for dtype_param in dtype_params: for backend_param in BACKEND_PARAMS: for dims_param in dims_params: - if test_type == OptTestType.default and _param_value(dims_param)[0] != _param_value(device_param): + if test_type == OptTestType.normal and _param_value(dims_param)[0] != _param_value(device_param): continue if _should_skip( test_type, @@ -155,11 +154,11 @@ def _build_params(test_type: OptTestType) -> list[ParameterSet]: return params -@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims_spec", _build_params(OptTestType.default)) -def test_default(opttest, device_type, dtype, backend, dims_spec, gpu_device): +@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims_spec", _build_params(OptTestType.normal)) +def test_normal(opttest, device_type, dtype, backend, dims_spec, gpu_device): _, dims = dims_spec device = torch.device(gpu_device if device_type == DeviceType.gpu else "cpu") - run_test(opttest, device, dtype, backend, OptTestType.default, dims=dims) + run_test(opttest, device, dtype, backend, OptTestType.normal, dims=dims) @pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(OptTestType.gradient_release)) From 647cf5bcdd624937c23f49f96ecadd909d5e5c0d Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 26 Dec 2025 11:50:55 -0600 Subject: [PATCH 12/16] refactor skip logic --- test_new/config.py | 32 +++++++++++++++++++ test_new/test_optimizers_unified.py | 48 ++++++++++------------------- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/test_new/config.py b/test_new/config.py index 06dbc21..0265251 100644 --- a/test_new/config.py +++ b/test_new/config.py @@ -13,6 +13,9 @@ from torch.optim import Optimizer +from optimi.utils import MIN_TORCH_2_6 + + class OptTestType(Enum): normal = "normal" gradient_release = "gradient_release" @@ -23,12 +26,39 @@ class DeviceType(Enum): cpu = "cpu" gpu = "gpu" + def is_available(self) -> bool: + if self == DeviceType.cpu: + return True + if self == DeviceType.gpu: + return ( + torch.cuda.is_available() + or (hasattr(torch, "xpu") and torch.xpu.is_available()) + or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) + ) + return False + class Backend(Enum): torch = "torch" triton = "triton" foreach = "foreach" + def is_supported(self, device: DeviceType) -> bool: + if self == Backend.triton: + # Triton requires torch >= 2.6 + if not MIN_TORCH_2_6: + return False + # Triton not supported on CPU + if device == DeviceType.cpu: + return False + # Triton not supported on MPS + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return False + # Triton requires GPU/XPU + if not 
(torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())): + return False + return True + @dataclass class Tolerance: @@ -223,6 +253,8 @@ def supports_l2_weight_decay(self) -> bool: return "decouple_wd" in inspect.signature(self.optimi_class.__init__).parameters + + def default_variants(base: OptTest) -> list[OptTest]: """Generate base + L2 + decoupled variants with minimal boilerplate.""" out: list[OptTest] = [] diff --git a/test_new/test_optimizers_unified.py b/test_new/test_optimizers_unified.py index 17230ec..66fb8cc 100644 --- a/test_new/test_optimizers_unified.py +++ b/test_new/test_optimizers_unified.py @@ -1,7 +1,7 @@ import pytest import torch from _pytest.mark.structures import ParameterSet -from optimi.utils import MIN_TORCH_2_6 + from .config import Backend, DeviceType, OptTest, OptTestType, discover_tests from .runner import run_test @@ -41,47 +41,27 @@ def _should_skip(test_type: OptTestType, opttest: OptTest, device_type: DeviceType, dtype: torch.dtype, backend: Backend) -> bool: - # Explicit per-opttest skip - if test_type in set(opttest.skip_tests): - return True - - # Respect per-test dtype constraints if provided - if opttest.only_dtypes and dtype not in opttest.only_dtypes: + # 1. Hardware availability + if not device_type.is_available(): return True - # Skip triton on CPU - if backend == Backend.triton and device_type == DeviceType.cpu: + # 2. Backend support for hardware + if not backend.is_supported(device_type): return True - # Triton requires torch >= 2.6 - if backend == Backend.triton and not MIN_TORCH_2_6: - return True - - # Triton is not supported on MPS - if ( - backend == Backend.triton - and not (torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())) - and (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) - ): + # 3. Explicit per-opttest skip + if test_type in set(opttest.skip_tests): return True - # GPU availability - if device_type == DeviceType.gpu and not ( - torch.cuda.is_available() - or (hasattr(torch, "xpu") and torch.xpu.is_available()) - or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) - ): + # 4. Respect per-test dtype constraints if provided + if opttest.only_dtypes and dtype not in opttest.only_dtypes: return True - # Gradient release and accumulation are GPU-only tests + # 5. Gradient release and accumulation are GPU-only tests if test_type in (OptTestType.gradient_release, OptTestType.accumulation) and device_type == DeviceType.cpu: return True - # Gradient release / accumulation are incompatible with foreach - if test_type in (OptTestType.gradient_release, OptTestType.accumulation) and backend == Backend.foreach: - return True - - # bfloat16 is not supported on MPS + # 6. bfloat16 is not supported on MPS if ( device_type == DeviceType.gpu and dtype == torch.bfloat16 @@ -90,10 +70,14 @@ def _should_skip(test_type: OptTestType, opttest: OptTest, device_type: DeviceTy ): return True - # Skip bfloat16 on CPU for most optimizers; allow anyadam exception via opttest.any_precision + # 7. Skip bfloat16 on CPU for most optimizers; allow anyadam exception via opttest.any_precision if device_type == DeviceType.cpu and dtype == torch.bfloat16 and not opttest.any_precision: return True + # 8. 
Skip foreach for gradient release and accumulation tests + if test_type != OptTestType.normal and backend == Backend.foreach: + return True + return False From 24a60fe01d05e5504b6682597d927fd360da6218 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 26 Dec 2025 12:45:35 -0600 Subject: [PATCH 13/16] fixes --- test_new/config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test_new/config.py b/test_new/config.py index 0265251..28446ae 100644 --- a/test_new/config.py +++ b/test_new/config.py @@ -131,7 +131,7 @@ def with_updated_spec( ) -> TestSpec: if isinstance(spec, (NormalSpec, GradientReleaseSpec, AccumulationSpec)): if isinstance(spec, NormalSpec): - base = TestSpec(default=spec) + base = TestSpec(normal=spec) elif isinstance(spec, GradientReleaseSpec): base = TestSpec(gradient_release=spec) else: @@ -149,9 +149,11 @@ def with_updated_spec( merged = {**base.normal.tolerance, **tolerances_override} return replace(base, normal=replace(base.normal, tolerance=merged)) if test_type == OptTestType.gradient_release: - return replace(base, gradient_release=replace(base.gradient_release, **tolerances_override)) + merged = {**base.gradient_release.tolerance, **tolerances_override} + return replace(base, gradient_release=replace(base.gradient_release, tolerance=merged)) if test_type == OptTestType.accumulation: - return replace(base, accumulation=replace(base.accumulation, **tolerances_override)) + merged = {**base.accumulation.tolerance, **tolerances_override} + return replace(base, accumulation=replace(base.accumulation, tolerance=merged)) raise ValueError(f"Unknown test type: {test_type}") From 13f82ddc4dfd6fb2bee7aa73928e4c838d81d793 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 26 Dec 2025 12:59:06 -0600 Subject: [PATCH 14/16] switch to new tests --- pyproject.toml | 5 +- test_new/__init__.py | 1 - test_new/conftest.py | 75 ---- tests/adam_test.py | 75 ---- tests/adan_test.py | 73 ---- tests/anyadam_test.py | 48 --- {test_new => tests}/config.py | 0 tests/conftest.py | 77 +++- tests/lion_test.py | 68 --- {test_new => tests}/opt_adam.py | 0 {test_new => tests}/opt_adamw.py | 0 {test_new => tests}/opt_adan.py | 0 {test_new => tests}/opt_anyadam.py | 0 {test_new => tests}/opt_lion.py | 0 {test_new => tests}/opt_radam.py | 0 {test_new => tests}/opt_ranger.py | 0 {test_new => tests}/opt_sgd.py | 0 {test_new => tests}/opt_stableadamw.py | 0 tests/optimizer_test.py | 400 ------------------ tests/radam_test.py | 75 ---- tests/ranger_test.py | 67 --- {test_new => tests}/runner.py | 0 tests/sgd_test.py | 77 ---- tests/stableadam_test.py | 68 --- .../test_optimizers.py | 0 25 files changed, 71 insertions(+), 1038 deletions(-) delete mode 100644 test_new/__init__.py delete mode 100644 test_new/conftest.py delete mode 100644 tests/adam_test.py delete mode 100644 tests/adan_test.py delete mode 100644 tests/anyadam_test.py rename {test_new => tests}/config.py (100%) delete mode 100644 tests/lion_test.py rename {test_new => tests}/opt_adam.py (100%) rename {test_new => tests}/opt_adamw.py (100%) rename {test_new => tests}/opt_adan.py (100%) rename {test_new => tests}/opt_anyadam.py (100%) rename {test_new => tests}/opt_lion.py (100%) rename {test_new => tests}/opt_radam.py (100%) rename {test_new => tests}/opt_ranger.py (100%) rename {test_new => tests}/opt_sgd.py (100%) rename {test_new => tests}/opt_stableadamw.py (100%) delete mode 100644 tests/optimizer_test.py delete mode 100644 tests/radam_test.py delete mode 100644 tests/ranger_test.py rename {test_new 
=> tests}/runner.py (100%) delete mode 100644 tests/sgd_test.py delete mode 100644 tests/stableadam_test.py rename test_new/test_optimizers_unified.py => tests/test_optimizers.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 589eea9..3e242ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,9 @@ markers = [ "sgd", "stableadam", ] +filterwarnings = [ + "ignore:Parameter `foreach` is deprecated:DeprecationWarning", +] [tool.ruff] line-length = 140 @@ -157,4 +160,4 @@ explicit = true [[tool.uv.index]] name = "pytorch-xpu" url = "https://download.pytorch.org/whl/xpu" -explicit = true \ No newline at end of file +explicit = true diff --git a/test_new/__init__.py b/test_new/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/test_new/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test_new/conftest.py b/test_new/conftest.py deleted file mode 100644 index eda353d..0000000 --- a/test_new/conftest.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Pytest configuration and fixtures for the unified optimizer test framework. - -This module provides pytest configuration, custom mark registration, and the -`gpu_device` fixture used by tests. -""" - -import pytest -import torch - -from .config import optimizer_names - - -def pytest_configure(config): - "Configure pytest with custom marks for optimizer testing." - # Register device marks - config.addinivalue_line("markers", "cpu: mark test to run on CPU") - config.addinivalue_line("markers", "gpu: mark test to run on GPU") - - # Register dtype marks - config.addinivalue_line("markers", "float32: mark test to run with float32 dtype") - config.addinivalue_line("markers", "bfloat16: mark test to run with bfloat16 dtype") - - # Register backend marks - config.addinivalue_line("markers", "torch: mark test to run with torch backend") - config.addinivalue_line("markers", "foreach: mark test to run with foreach backend") - config.addinivalue_line("markers", "triton: mark test to run with triton backend") - - # Per-optimizer marks (e.g., -m adam, -m sgd) - for opt_name in optimizer_names(): - config.addinivalue_line("markers", f"{opt_name}: mark test for {opt_name} optimizer") - - -def pytest_addoption(parser): - "Add command-line option to specify a single GPU." - parser.addoption("--gpu-id", action="store", type=int, default=None, help="Specify a single GPU to use (e.g. --gpu-id=0)") - - -@pytest.fixture() -def gpu_device(worker_id, request): - """Map xdist workers to available GPU devices in a round-robin fashion, - supporting CUDA (NVIDIA/ROCm) and XPU (Intel) backends. 
- Use a single specified GPU if --gpu-id is provided""" - - # Check if specific GPU was requested - specific_gpu = request.config.getoption("--gpu-id") - - # Determine available GPU backend and device count - if torch.cuda.is_available(): - backend = "cuda" - device_count = torch.cuda.device_count() - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - backend = "xpu" - device_count = torch.xpu.device_count() - elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - backend = "mps" - device_count = 0 - else: - raise RuntimeError("No GPU backend available") - - if specific_gpu is not None: - return torch.device(f"{backend}:{specific_gpu}") - - if worker_id == "master": - return torch.device(backend) - - # If no devices available, return default backend - if device_count == 0: - return torch.device(backend) - - # Extract worker number from worker_id (e.g., 'gw6' -> 6) - worker_num = int(worker_id.replace("gw", "")) - - # Map worker to GPU index using modulo to round-robin - gpu_idx = (worker_num - 1) % device_count - return torch.device(f"{backend}:{gpu_idx}") diff --git a/tests/adam_test.py b/tests/adam_test.py deleted file mode 100644 index 022d698..0000000 --- a/tests/adam_test.py +++ /dev/null @@ -1,75 +0,0 @@ -from itertools import product - -import pytest -import torch - -import optimi -from tests import reference - -from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_dtype, - cpu_ftype, gpu_dim1, gpu_dim2, gpu_dtype, gpu_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation, gpu_device) - - -optimizers = {} - -optimizers["adam"] = ({'optim':torch.optim.Adam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0)}, - {'optim':optimi.Adam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0)}) - -optimizers["adam_l2"] = ({'optim':torch.optim.Adam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2)}, - {'optim':optimi.Adam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2, decouple_wd=False)}) - -optimizers["adam_dw"] = ({'optim':torch.optim.AdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2)}, - {'optim':optimi.Adam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2, decouple_wd=True)}) - -optimizers["adamw"] = ({'optim':torch.optim.AdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2)}, - {'optim':optimi.AdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2)}) - -optimizers["adamw_dlr"] = ({'optim':reference.DecoupledAdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-5)}, - {'optim':optimi.AdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-5, decouple_lr=True)}) - -optimizer_names = [key for key in optimizers.keys()] - - - -cpu_values = list(product(cpu_dim1, cpu_dim2, cpu_dtype, optimizer_names, cpu_ftype)) -cpu_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] - -@pytest.mark.cpu -@pytest.mark.adam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cpu_values, ids=cpu_names) -def test_optimizer_cpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device('cpu'), buffer) - - - -cuda_values = list(product(gpu_dim1, gpu_dim2, gpu_dtype, optimizer_names, gpu_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] 
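The gpu_device fixture above assigns pytest-xdist workers to GPUs in a round-robin fashion; the mapping is small enough to sketch on its own. A minimal, illustrative version (the helper name and device count are assumptions, not part of the test suite):

    def worker_to_gpu_index(worker_id: str, device_count: int) -> int:
        # "gw6" -> worker 6 -> index (6 - 1) % device_count, mirroring the fixture
        worker_num = int(worker_id.replace("gw", ""))
        return (worker_num - 1) % device_count

    assert worker_to_gpu_index("gw1", 4) == 0
    assert worker_to_gpu_index("gw6", 4) == 1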
- -@pytest.mark.gpu -@pytest.mark.adam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_gpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), buffer) - - - -cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.adam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_gradient_release(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - gradient_release(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) - - -@pytest.mark.gpu -@pytest.mark.adam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_accumulation(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - if optim_name in ["adam_l2"]: - pytest.skip("Skip tests for Adam with L2 weight decay.") - optimizer_accumulation(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) \ No newline at end of file diff --git a/tests/adan_test.py b/tests/adan_test.py deleted file mode 100644 index 6086452..0000000 --- a/tests/adan_test.py +++ /dev/null @@ -1,73 +0,0 @@ -from itertools import product - -import pytest -import torch - -import optimi -from tests import reference - -from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_dtype, - cpu_ftype, gpu_dim1, gpu_dim2, gpu_dtype, gpu_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation, gpu_device) - - - -optimizers = {} - -optimizers["adan"] = ({'optim':reference.Adan, 'kwargs':dict(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6)}, - {'optim':optimi.Adan, 'kwargs':dict(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=0)}) - -optimizers["adan_wd"] = ({'optim':reference.Adan, 'kwargs':dict(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2)}, - {'optim':optimi.Adan, 'kwargs':dict(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2)}) - -optimizers["adan_awd"] = ({'optim':reference.Adan, 'kwargs':dict(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, no_prox=True)}, - {'optim':optimi.Adan, 'kwargs':dict(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2, adam_wd=True)}) - -optimizers["adan_dlr"] = ({'optim':reference.Adan, 'kwargs':dict(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-2)}, - {'optim':optimi.Adan, 'kwargs':dict(lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-6, weight_decay=2e-5, decouple_lr=True)}) - -optimizer_names = [key for key in optimizers.keys()] - - - -cpu_values = list(product(cpu_dim1, cpu_dim2, cpu_dtype, optimizer_names, cpu_ftype)) -cpu_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] - -@pytest.mark.cpu -@pytest.mark.adan -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cpu_values, ids=cpu_names) -def test_optimizer_cpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device('cpu'), buffer) - - - -cuda_values = list(product(gpu_dim1, gpu_dim2, gpu_dtype, 
optimizer_names, gpu_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.adan -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_gpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - # Adan bfloat16 updates are noisier, so GPU uses fewer test iterations - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), buffer, - iterations=20 if dtype == torch.bfloat16 else None) - - - -cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.adan -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_gradient_release(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - gradient_release(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) - - -@pytest.mark.gpu -@pytest.mark.adan -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_accumulation(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - optimizer_accumulation(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) diff --git a/tests/anyadam_test.py b/tests/anyadam_test.py deleted file mode 100644 index 3d51011..0000000 --- a/tests/anyadam_test.py +++ /dev/null @@ -1,48 +0,0 @@ -from itertools import product - -import pytest -import torch - -import optimi -from tests import reference - -from tests.optimizer_test import buffer, run_optimizer, cpu_dim1, cpu_dim2, cpu_dtype, cpu_ftype, gpu_dim1, gpu_dim2, gpu_dtype, gpu_ftype, gpu_device - - - -optimizers = {} - -optimizers["any_adam"] = ({'optim':reference.AnyPrecisionAdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0)}, - {'optim':optimi.Adam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, kahan_sum=True)}) - -optimizers["any_adamw"] = ({'optim':reference.AnyPrecisionAdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2)}, - {'optim':optimi.AdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2, kahan_sum=True)}) - -optimizers["any_adamw_dlr"] = ({'optim':reference.AnyPrecisionAdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-2)}, - {'optim':optimi.AdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=1e-5, decouple_lr=True, kahan_sum=True)}) - -optimizer_names = [key for key in optimizers.keys()] - - - -cpu_dtype = [torch.bfloat16] -cpu_values = list(product(cpu_dim1, cpu_dim2, cpu_dtype, optimizer_names, cpu_ftype)) -cpu_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] - -@pytest.mark.cpu -@pytest.mark.adam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cpu_values, ids=cpu_names) -def test_optimizer_cpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device('cpu'), buffer, any_precision=True) - - - -gpu_dtype = [torch.bfloat16] -cuda_values = list(product(gpu_dim1, gpu_dim2, gpu_dtype, optimizer_names, gpu_ftype)) -cuda_names = 
["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.adam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_gpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), buffer, any_precision=True) \ No newline at end of file diff --git a/test_new/config.py b/tests/config.py similarity index 100% rename from test_new/config.py rename to tests/config.py diff --git a/tests/conftest.py b/tests/conftest.py index 07c912e..eda353d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,75 @@ +"""Pytest configuration and fixtures for the unified optimizer test framework. + +This module provides pytest configuration, custom mark registration, and the +`gpu_device` fixture used by tests. +""" + +import pytest import torch -try: - import triton -except ImportError: - triton = None +from .config import optimizer_names -def pytest_report_header(config): - if triton is None: - return f"libaries: PyTorch {torch.__version__}" - else: - return f"libaries: PyTorch {torch.__version__}, Triton: {triton.__version__}" +def pytest_configure(config): + "Configure pytest with custom marks for optimizer testing." + # Register device marks + config.addinivalue_line("markers", "cpu: mark test to run on CPU") + config.addinivalue_line("markers", "gpu: mark test to run on GPU") + + # Register dtype marks + config.addinivalue_line("markers", "float32: mark test to run with float32 dtype") + config.addinivalue_line("markers", "bfloat16: mark test to run with bfloat16 dtype") + + # Register backend marks + config.addinivalue_line("markers", "torch: mark test to run with torch backend") + config.addinivalue_line("markers", "foreach: mark test to run with foreach backend") + config.addinivalue_line("markers", "triton: mark test to run with triton backend") + + # Per-optimizer marks (e.g., -m adam, -m sgd) + for opt_name in optimizer_names(): + config.addinivalue_line("markers", f"{opt_name}: mark test for {opt_name} optimizer") def pytest_addoption(parser): - """Add command-line option to specify a single GPU""" + "Add command-line option to specify a single GPU." parser.addoption("--gpu-id", action="store", type=int, default=None, help="Specify a single GPU to use (e.g. --gpu-id=0)") + + +@pytest.fixture() +def gpu_device(worker_id, request): + """Map xdist workers to available GPU devices in a round-robin fashion, + supporting CUDA (NVIDIA/ROCm) and XPU (Intel) backends. 
+ Use a single specified GPU if --gpu-id is provided""" + + # Check if specific GPU was requested + specific_gpu = request.config.getoption("--gpu-id") + + # Determine available GPU backend and device count + if torch.cuda.is_available(): + backend = "cuda" + device_count = torch.cuda.device_count() + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + backend = "xpu" + device_count = torch.xpu.device_count() + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + backend = "mps" + device_count = 0 + else: + raise RuntimeError("No GPU backend available") + + if specific_gpu is not None: + return torch.device(f"{backend}:{specific_gpu}") + + if worker_id == "master": + return torch.device(backend) + + # If no devices available, return default backend + if device_count == 0: + return torch.device(backend) + + # Extract worker number from worker_id (e.g., 'gw6' -> 6) + worker_num = int(worker_id.replace("gw", "")) + + # Map worker to GPU index using modulo to round-robin + gpu_idx = (worker_num - 1) % device_count + return torch.device(f"{backend}:{gpu_idx}") diff --git a/tests/lion_test.py b/tests/lion_test.py deleted file mode 100644 index f35e783..0000000 --- a/tests/lion_test.py +++ /dev/null @@ -1,68 +0,0 @@ -from itertools import product - -import pytest -import torch - -import optimi -from tests import reference - -from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_dtype, - cpu_ftype, gpu_dim1, gpu_dim2, gpu_dtype, gpu_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation, gpu_device) - - - -optimizers = {} - -optimizers["lion"] = ({'optim':reference.Lion, 'kwargs':dict(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)}, - {'optim':optimi.Lion, 'kwargs':dict(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)}) - -optimizers["lion_wd"] = ({'optim':reference.Lion, 'kwargs':dict(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)}, - {'optim':optimi.Lion, 'kwargs':dict(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)}) - -optimizers["lion_dlr"] = ({'optim':reference.Lion, 'kwargs':dict(lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)}, - {'optim':optimi.Lion, 'kwargs':dict(lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-5, decouple_lr=True)}) - -optimizer_names = [key for key in optimizers.keys()] - - - -cpu_values = list(product(cpu_dim1, cpu_dim2, cpu_dtype, optimizer_names, cpu_ftype)) -cpu_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] - -@pytest.mark.cpu -@pytest.mark.lion -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cpu_values, ids=cpu_names) -def test_optimizer_cpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device('cpu'), buffer) - - - -cuda_values = list(product(gpu_dim1, gpu_dim2, gpu_dtype, optimizer_names, gpu_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.lion -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_gpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), buffer) - - - -cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.lion 
-@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_gradient_release(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - gradient_release(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) - - -@pytest.mark.gpu -@pytest.mark.lion -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_accumulation(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - optimizer_accumulation(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) \ No newline at end of file diff --git a/test_new/opt_adam.py b/tests/opt_adam.py similarity index 100% rename from test_new/opt_adam.py rename to tests/opt_adam.py diff --git a/test_new/opt_adamw.py b/tests/opt_adamw.py similarity index 100% rename from test_new/opt_adamw.py rename to tests/opt_adamw.py diff --git a/test_new/opt_adan.py b/tests/opt_adan.py similarity index 100% rename from test_new/opt_adan.py rename to tests/opt_adan.py diff --git a/test_new/opt_anyadam.py b/tests/opt_anyadam.py similarity index 100% rename from test_new/opt_anyadam.py rename to tests/opt_anyadam.py diff --git a/test_new/opt_lion.py b/tests/opt_lion.py similarity index 100% rename from test_new/opt_lion.py rename to tests/opt_lion.py diff --git a/test_new/opt_radam.py b/tests/opt_radam.py similarity index 100% rename from test_new/opt_radam.py rename to tests/opt_radam.py diff --git a/test_new/opt_ranger.py b/tests/opt_ranger.py similarity index 100% rename from test_new/opt_ranger.py rename to tests/opt_ranger.py diff --git a/test_new/opt_sgd.py b/tests/opt_sgd.py similarity index 100% rename from test_new/opt_sgd.py rename to tests/opt_sgd.py diff --git a/test_new/opt_stableadamw.py b/tests/opt_stableadamw.py similarity index 100% rename from test_new/opt_stableadamw.py rename to tests/opt_stableadamw.py diff --git a/tests/optimizer_test.py b/tests/optimizer_test.py deleted file mode 100644 index 6eecc4a..0000000 --- a/tests/optimizer_test.py +++ /dev/null @@ -1,400 +0,0 @@ -# Optimizer testing modified from bitsandbytes: https://github.com/TimDettmers/bitsandbytes/blob/main/tests/test_optim.py -# bitsandbytes - MIT License - Copyright (c) Facebook, Inc. and its affiliates. - -import inspect -import io -from typing import Optional - -import pytest -import torch -from torch import Tensor -from optimi.utils import MIN_TORCH_2_6 -from optimi import prepare_for_gradient_release, remove_gradient_release - - -@pytest.fixture() -def gpu_device(worker_id, request): - """Map xdist workers to available GPU devices in a round-robin fashion, - supporting CUDA (NVIDIA/ROCm) and XPU (Intel) backends. 
- Use a single specified GPU if --gpu-id is provided""" - - # Check if specific GPU was requested - specific_gpu = request.config.getoption("--gpu-id") - - # Determine available GPU backend and device count - if torch.cuda.is_available(): - backend = "cuda" - device_count = torch.cuda.device_count() - elif hasattr(torch, 'xpu') and torch.xpu.is_available(): - backend = "xpu" - device_count = torch.xpu.device_count() - else: - # Fallback to cuda for compatibility - backend = "cuda" - device_count = 0 - - if specific_gpu is not None: - return f"{backend}:{specific_gpu}" - - if worker_id == "master": - return backend - - # If no devices available, return default backend - if device_count == 0: - return backend - - # Extract worker number from worker_id (e.g., 'gw6' -> 6) - worker_num = int(worker_id.replace('gw', '')) - - # Map worker to GPU index using modulo to round-robin - gpu_idx = (worker_num - 1) % device_count - return f"{backend}:{gpu_idx}" - - -class MLP(torch.nn.Module): - def __init__(self, input_size, hidden_size, device, dtype): - super().__init__() - self.norm = torch.nn.LayerNorm(input_size, device=device, dtype=dtype) - self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=False, device=device, dtype=dtype) - self.act = torch.nn.Mish() - self.fc2 = torch.nn.Linear(hidden_size, 1, bias=False, device=device, dtype=dtype) - - def forward(self, x): - x = self.norm(x) - x = self.fc1(x) - x = self.act(x) - x = self.fc2(x) - return x - - -def assert_most_approx_close(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3, max_error_count: int = 0, max_error_rate: float | None = None, name: str = ''): - idx = torch.isclose(a.float(), b.float(), rtol=rtol, atol=atol) - error_count = (idx == 0).sum().item() - if max_error_rate is not None: - if error_count > (a.numel()) * max_error_rate and error_count > max_error_count: - print(f"{name}Too many values not close: assert {error_count} < {(a.numel()) * max_error_rate}") - torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) - elif error_count > max_error_count: - print(f"{name}Too many values not close: assert {error_count} < {max_error_count}") - torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) - - -def load_optimizer(params, optimizers, optim_name, key, ftype, skip=False) -> torch.optim.Optimizer: - def update_kwargs(key, argspec, value=True): - if key in argspec.kwonlyargs or key in argspec.args: - kwargs.update({key: value}) - elif value and skip: - pytest.skip(f"Skipping {key} for {optim_name}") - - if optim_name in optimizers: - optimizer = optimizers[optim_name][key]['optim'] - kwargs = optimizers[optim_name][key]['kwargs'] - else: - raise ValueError(f"{optim_name} optimizer not defined") - - argspec = inspect.getfullargspec(optimizer) - update_kwargs('fused', argspec, False) - update_kwargs('foreach', argspec, False) - update_kwargs('triton', argspec, False) - if ftype != '': - update_kwargs(ftype, argspec, True) - - return optimizer(params, **kwargs) - - -def run_optimizer(optimizers:dict, dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, - ftype:str, device:torch.device, buffer:io.BytesIO, iterations:Optional[int]=None, - any_precision:bool=False, atol_override:Optional[dict[torch.dtype, float]]=None, - rtol_override:Optional[dict[torch.dtype, float]]=None, - max_error_rate_override:Optional[dict[torch.dtype, float]]=None): - if dim1 == 1 and dim2 == 1: - pytest.skip("Skipping 1x1 optimizer test") - - ftype = ftype.replace('_', '') - - if atol_override is None: - 
atol_override = {} - if rtol_override is None: - rtol_override = {} - if max_error_rate_override is None: - max_error_rate_override = {} - - if iterations is None: - if device == torch.device('cpu'): - iterations = 20 - else: - iterations = 40 - - # allow for a small number of errors on low dimension tests - max_error_count = 2 if device == torch.device('cpu') else 5 - - if dtype == torch.float32: - atol = atol_override.get(torch.float32, 1e-6) - rtol = rtol_override.get(torch.float32, 1e-5) - max_error_rate = max_error_rate_override.get(torch.float32, 0.0005) - elif dtype == torch.bfloat16: - atol = atol_override.get(torch.bfloat16, 1e-3) - rtol = rtol_override.get(torch.bfloat16, 1e-2) - max_error_rate = max_error_rate_override.get(torch.bfloat16, 0.01) - elif dtype == torch.float16: - atol = atol_override.get(torch.float16, 1e-4) - rtol = rtol_override.get(torch.float16, 1e-3) - max_error_rate = max_error_rate_override.get(torch.float16, 0.01) - - # Create MLP models instead of simple parameters - m1 = MLP(dim1, dim2, device=device, dtype=dtype) - m2 = MLP(dim1, dim2, device=device, dtype=dtype) - m2.load_state_dict(m1.state_dict()) - - # Convert model parameters to float for non-any_precision testing - if not any_precision and dtype != torch.float32: - for p in m1.parameters(): - p.data = p.data.float() - - torch_optimizer = load_optimizer(m1.parameters(), optimizers, optim_name, 0, ftype) - optimi_optimizer = load_optimizer(m2.parameters(), optimizers, optim_name, 1, ftype, skip=True) - - bs = 1 if device.type == "cpu" else 32 - - for i in range(iterations): - # Training loop with input/target generation - input1 = torch.randn(bs, dim1, device=device, dtype=dtype) - input2 = input1.detach().clone() - target1 = torch.randn(bs, 1, device=device, dtype=dtype) - target2 = target1.detach().clone() - - # Convert model parameters to float for non-any_precision testing - if not any_precision and dtype != torch.float32: - input1 = input1.float() - target1 = target1.float() - - # Forward pass - output1 = m1(input1) - output2 = m2(input2) - - # Loss calculation - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - - # Backward pass - loss1.backward() - loss2.backward() - - # Optimizer step - optimi_optimizer.step() - torch_optimizer.step() - - # Zero gradients - optimi_optimizer.zero_grad() - torch_optimizer.zero_grad() - - # Compare model weights - assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, atol=atol, rtol=rtol, - max_error_count=max_error_count, max_error_rate=max_error_rate, - name='fc1: ') - assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, atol=atol, rtol=rtol, - max_error_count=max_error_count, max_error_rate=max_error_rate, - name='fc2: ') - - # # Test state_dict saving and loading periodically - if i % (iterations // 10) == 0 and i > 0: - # Save optimizer state - torch.save(optimi_optimizer.state_dict(), buffer) - buffer.seek(0) - # Load checkpoint - ckpt = torch.load(buffer, weights_only=True) - # Recreate optimizer and load its state - optimi_optimizer = load_optimizer(m2.parameters(), optimizers, optim_name, 1, ftype) - optimi_optimizer.load_state_dict(ckpt) - # Clear buffer - buffer.seek(0) - buffer.truncate(0) - - # Verify models are still aligned after state_dict loading - assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, atol=atol, rtol=rtol, - max_error_count=max_error_count, max_error_rate=max_error_rate, - name='fc1 after load: ') - assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, 
atol=atol, rtol=rtol, - max_error_count=max_error_count, max_error_rate=max_error_rate, - name='fc2 after load: ') - - -def gradient_release(optimizers:dict, dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, - ftype:str, device:torch.device, iterations:int=40, framework_opt_step:bool=False, - atol_override:Optional[dict[torch.dtype, float]]=None, - rtol_override:Optional[dict[torch.dtype, float]]=None, - max_error_rate_override:Optional[dict[torch.dtype, float]]=None): - def optimizer_hook(parameter) -> None: - torch_optimizers[parameter].step() - torch_optimizers[parameter].zero_grad() - - # Since Lion & Adan can have noisy updates, allow up to 12 errors - max_error_count = 12 - - if atol_override is None: - atol_override = {} - if rtol_override is None: - rtol_override = {} - if max_error_rate_override is None: - max_error_rate_override = {} - - if dtype == torch.float32: - atol = atol_override.get(torch.float32, 2e-6) - rtol = rtol_override.get(torch.float32, 1e-5) - elif dtype == torch.bfloat16: - atol = atol_override.get(torch.bfloat16, 2e-3) - rtol = rtol_override.get(torch.bfloat16, 1e-2) - elif dtype == torch.float16: - atol = atol_override.get(torch.float16, 2e-4) - rtol = rtol_override.get(torch.float16, 1e-3) - - m1 = MLP(dim1, dim2, device=device, dtype=dtype) - m2 = MLP(dim1, dim2, device=device, dtype=dtype) - m3 = MLP(dim1, dim2, device=device, dtype=dtype) - m2.load_state_dict(m1.state_dict()) - m3.load_state_dict(m1.state_dict()) - - regular_optimizer = load_optimizer(m1.parameters(), optimizers, optim_name, 0, ftype) - - - # PyTorch Method: taken from https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - torch_optimizers = {p: load_optimizer([p], optimizers, optim_name, 0, ftype) for p in m2.parameters()} - - pytorch_hooks = [] - for p in m2.parameters(): - pytorch_hooks.append(p.register_post_accumulate_grad_hook(optimizer_hook)) - - - # Optimim Method - # add the gradient release flag to the optimizer kwargs - optimizers[optim_name][1]['kwargs']['gradient_release'] = True - optimi_optimizer = load_optimizer(m3.parameters(), optimizers, optim_name, 1, ftype) - - prepare_for_gradient_release(m3, optimi_optimizer) - bs = 1 if device.type == "cpu" else 32 - - - # Training loop - for i in range(iterations): - input1 = torch.randn(bs, dim1, device=device, dtype=dtype) - input2 = input1.clone() - input3 = input1.clone() - target1 = torch.randn(bs, 1, device=device, dtype=dtype) - target2 = target1.clone() - target3 = target1.clone() - - output1 = m1(input1) - output2 = m2(input2) - output3 = m3(input3) - - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - loss3 = torch.nn.functional.mse_loss(output3, target3) - - loss1.backward() - loss2.backward() - loss3.backward() - - regular_optimizer.step() - regular_optimizer.zero_grad() - - # simulates using an optimi gradient release optimizer in a framework - # where the optimizer step and zero_grad cannot be disabled. 
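The framework_opt_step branch below exercises calling step() and zero_grad() on a gradient release optimizer even though the per-parameter hooks already stepped it. A rough usage sketch of that flow against the optimi API exercised in these tests (the model and training loop are illustrative, not taken from the suite):

    import torch
    import optimi
    from optimi import prepare_for_gradient_release, remove_gradient_release

    model = torch.nn.Linear(8, 1)
    opt = optimi.AdamW(model.parameters(), lr=1e-3, gradient_release=True)
    prepare_for_gradient_release(model, opt)

    for _ in range(2):
        loss = model(torch.randn(4, 8)).mean()
        loss.backward()  # post-accumulate-grad hooks step the optimizer per parameter
        opt.step()       # simulates a framework that always calls step/zero_grad
        opt.zero_grad()

    remove_gradient_release(model)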
- if framework_opt_step: - optimi_optimizer.step() - optimi_optimizer.zero_grad() - - assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, rtol=rtol, atol=atol, - max_error_count=max_error_count, name='PyTorch-PyTorch: ') - assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=rtol, atol=atol, - max_error_count=max_error_count, name='PyTorch-PyTorch: ') - assert_most_approx_close(m1.fc1.weight, m3.fc1.weight, rtol=rtol, atol=atol, - max_error_count=max_error_count, name='PyTorch-Optimi: ') - assert_most_approx_close(m1.fc2.weight, m3.fc2.weight, rtol=rtol, atol=atol, - max_error_count=max_error_count, name='PyTorch-Optimi: ') - - for h in pytorch_hooks: - h.remove() - remove_gradient_release(m3) - - -def optimizer_accumulation(optimizers:dict, dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, - ftype:str, device:torch.device, iterations:int=40, framework_opt_step:bool=False, - atol_override:Optional[dict[torch.dtype, float]]=None, - rtol_override:Optional[dict[torch.dtype, float]]=None, - max_error_rate_override:Optional[dict[torch.dtype, float]]=None): - # Since optimizer accumulation approximates gradient accumulation, the tolerances - # compared to normal optimizers are high despite the low number of iterations - max_error_rate = 0.035 - atol, rtol = 1e-2, 1e-2 - - m1 = MLP(dim1, dim2, device=device, dtype=dtype) - m2 = MLP(dim1, dim2, device=device, dtype=dtype) - m2.load_state_dict(m1.state_dict()) - - regular_optimizer = load_optimizer(m1.parameters(), optimizers, optim_name, 0, ftype) - - - # Optimim Method - # add the gradient release flag to the optimizer kwargs - optimizers[optim_name][1]['kwargs']['gradient_release'] = True - optimi_optimizer = load_optimizer(m2.parameters(), optimizers, optim_name, 1, ftype) - - prepare_for_gradient_release(m2, optimi_optimizer) - - gradient_accumulation_steps = 4 - bs = 1 if device.type == "cpu" else 32 - - # Training loop - for i in range(iterations): - input1 = torch.randn(bs, dim1, device=device, dtype=dtype) - input2 = input1.clone() - target1 = torch.randn(bs, 1, device=device, dtype=dtype) - target2 = target1.clone() - - optimi_optimizer.optimizer_accumulation = (i+1) % gradient_accumulation_steps != 0 - - output1 = m1(input1) - output2 = m2(input2) - - loss1 = torch.nn.functional.mse_loss(output1, target1) - loss2 = torch.nn.functional.mse_loss(output2, target2) - - loss1.backward() - loss2.backward() - - if not optimi_optimizer.optimizer_accumulation: - regular_optimizer.step() - regular_optimizer.zero_grad() - - # simulates using an optimi gradient release optimizer in a framework - # where the optimizer step and zero_grad cannot be disabled. 
- if framework_opt_step: - optimi_optimizer.step() - optimi_optimizer.zero_grad() - - # unlike other tests, compare that the weights are in the same approximate range at the end of training - assert_most_approx_close(m1.fc1.weight, m2.fc1.weight, rtol=rtol, atol=atol, max_error_rate=max_error_rate) - assert_most_approx_close(m1.fc2.weight, m2.fc2.weight, rtol=rtol, atol=atol, max_error_rate=max_error_rate) - - remove_gradient_release(m2) - - -buffer = io.BytesIO() - - -cpu_dim1 = [64] -cpu_dim2 = [64, 128] -cpu_dtype = [torch.float32] -cpu_ftype = ['', '_foreach'] - - -gpu_dim1 = [256] -gpu_dim2 = [256, 512, 1024, 2048] -gpu_dtype = [torch.float32, torch.bfloat16] -gpu_ftype = ['', '_foreach'] + (['_triton'] if MIN_TORCH_2_6 else []) - -gr_dim1 = [128] -gr_dim2 = [256, 1024] -gr_dtype = [torch.float32] -gr_ftype = [''] + (['_triton'] if MIN_TORCH_2_6 else []) \ No newline at end of file diff --git a/tests/radam_test.py b/tests/radam_test.py deleted file mode 100644 index 6012d85..0000000 --- a/tests/radam_test.py +++ /dev/null @@ -1,75 +0,0 @@ -from itertools import product - -import pytest -import torch - -import optimi -from packaging.version import parse - -from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_dtype, - cpu_ftype, gpu_dim1, gpu_dim2, gpu_dtype, gpu_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation, gpu_device) - -# PyTorch's RAdam adds epsilon before debiasing V while Optimi debases before. -# RAdam tests with a smaller epsilon then other optimizers to prevent numerical divergances. - -optimizers = {} - -optimizers["radam"] = ({'optim':torch.optim.RAdam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0)}, - {'optim':optimi.RAdam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0)}) - -optimizers["radam_l2"] = ({'optim':torch.optim.RAdam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=1e-2)}, - {'optim':optimi.RAdam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=1e-2, decouple_wd=False)}) - -if parse(torch.__version__) >= parse("2.2"): - optimizers["radamw"] = ({'optim':torch.optim.RAdam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=1e-2, decoupled_weight_decay=True)}, - {'optim':optimi.RAdam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=1e-2, decouple_wd=True)}) - - optimizers["radam_dlr"] = ({'optim':torch.optim.RAdam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=1e-2, decoupled_weight_decay=True)}, - {'optim':optimi.RAdam, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=1e-5, decouple_lr=True)}) - -optimizer_names = [key for key in optimizers.keys()] - - - -cpu_values = list(product(cpu_dim1, cpu_dim2, cpu_dtype, optimizer_names, cpu_ftype)) -cpu_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] - -@pytest.mark.cpu -@pytest.mark.radam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cpu_values, ids=cpu_names) -def test_optimizer_cpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device('cpu'), buffer) - - - -cuda_values = list(product(gpu_dim1, gpu_dim2, gpu_dtype, optimizer_names, gpu_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.radam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) 
-def test_optimizer_gpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), buffer, max_error_rate_override={torch.float32: 0.001}) - - - -cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.radam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_gradient_release(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - gradient_release(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5, max_error_rate_override={torch.float32: 0.001}) - - -@pytest.mark.gpu -@pytest.mark.radam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_accumulation(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - if optim_name in ["radam_l2"]: - pytest.skip("Skip tests for RAdam with L2 weight decay.") - optimizer_accumulation(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5, max_error_rate_override={torch.float32: 0.001}) diff --git a/tests/ranger_test.py b/tests/ranger_test.py deleted file mode 100644 index c38b3fb..0000000 --- a/tests/ranger_test.py +++ /dev/null @@ -1,67 +0,0 @@ -from itertools import product - -import pytest -import torch - -import optimi -from tests import reference - -from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_dtype, - cpu_ftype, gpu_dim1, gpu_dim2, gpu_dtype, gpu_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation, gpu_device) - -# The reference Ranger adds epsilon before debiasing V while Optimi debases before. -# Ranger tests with a smaller epsilon then other optimizers to prevent numerical divergances. 
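To make the epsilon-ordering note above concrete: with the same second-moment value, adding epsilon before debiasing versus after debiasing yields slightly different denominators, which is why these tests shrink epsilon to keep the two implementations numerically close. A rough illustration only, not the exact update code of either optimizer:

    import torch

    v, beta2, step, eps = torch.tensor(4e-8), 0.99, 10, 1e-8
    debias = 1 - beta2**step
    denom_reference = (v.sqrt() + eps) / debias**0.5  # epsilon added before debiasing
    denom_optimi = (v / debias).sqrt() + eps          # debias first, then add epsilon
    print(denom_reference.item(), denom_optimi.item())  # differ by roughly eps-scale terms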
- -optimizers = {} - -optimizers["ranger"] = ({'optim':reference.Ranger, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0)}, - {'optim':optimi.Ranger, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0)}) - -# reference doesn't perform the normal weight decay step, so no wd tests - -optimizer_names = [key for key in optimizers.keys()] - - - -cpu_values = list(product(cpu_dim1, cpu_dim2, cpu_dtype, optimizer_names, cpu_ftype)) -cpu_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] - -@pytest.mark.cpu -@pytest.mark.ranger -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cpu_values, ids=cpu_names) -def test_optimizer_cpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device('cpu'), buffer) - - - -cuda_values = list(product(gpu_dim1, gpu_dim2, gpu_dtype, optimizer_names, gpu_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.ranger -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_gpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - # test ranger longer due to the lookahead step - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), buffer) - - - -cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.ranger -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_gradient_release(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - # test ranger longer due to the lookahead step - gradient_release(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - iterations=160, framework_opt_step=torch.rand(1).item() > 0.5) - - -@pytest.mark.gpu -@pytest.mark.ranger -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_accumulation(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - optimizer_accumulation(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) diff --git a/test_new/runner.py b/tests/runner.py similarity index 100% rename from test_new/runner.py rename to tests/runner.py diff --git a/tests/sgd_test.py b/tests/sgd_test.py deleted file mode 100644 index 02df73a..0000000 --- a/tests/sgd_test.py +++ /dev/null @@ -1,77 +0,0 @@ -from itertools import product - -import pytest -import torch - -import optimi -from tests import reference - -from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_dtype, - cpu_ftype, gpu_dim1, gpu_dim2, gpu_dtype, gpu_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation, gpu_device) - - - -optimizers = {} - -optimizers["sgd"] = ({'optim':torch.optim.SGD, 'kwargs':dict(lr=1e-3, momentum=0, dampening=0, weight_decay=0)}, - {'optim':optimi.SGD, 'kwargs':dict(lr=1e-3, momentum=0, dampening=False, weight_decay=0)}) - -optimizers["sgd_mom"] = ({'optim':torch.optim.SGD, 'kwargs':dict(lr=1e-3, momentum=0.9, dampening=0, weight_decay=0)}, - {'optim':optimi.SGD, 'kwargs':dict(lr=1e-3, momentum=0.9, dampening=False, 
weight_decay=0)}) - -optimizers["sgd_damp"] = ({'optim':torch.optim.SGD, 'kwargs':dict(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=0)}, - {'optim':optimi.SGD, 'kwargs':dict(lr=1e-3, momentum=0.9, dampening=True, weight_decay=0, torch_init=True)}) - -optimizers["sgd_l2"] = ({'optim':torch.optim.SGD, 'kwargs':dict(lr=1e-3, momentum=0.9, dampening=0, weight_decay=1e-2)}, - {'optim':optimi.SGD, 'kwargs':dict(lr=1e-3, momentum=0.9, dampening=False, weight_decay=1e-2, decouple_wd=False)}) - -optimizers["sgdw_dlr"] = ({'optim':reference.DecoupledSGDW, 'kwargs':dict(lr=1e-3, momentum=0.9, dampening=0.9, weight_decay=1e-5)}, - {'optim':optimi.SGD, 'kwargs':dict(lr=1e-3, momentum=0.9, dampening=True, decouple_lr=True, weight_decay=1e-5, torch_init=True)}) - -optimizer_names = [key for key in optimizers.keys()] - - - -cpu_values = list(product(cpu_dim1, cpu_dim2, cpu_dtype, optimizer_names, cpu_ftype)) -cpu_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] - -@pytest.mark.cpu -@pytest.mark.sgd -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cpu_values, ids=cpu_names) -def test_optimizer_cpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device('cpu'), buffer) - - - -cuda_values = list(product(gpu_dim1, gpu_dim2, gpu_dtype, optimizer_names, gpu_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.sgd -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_gpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), buffer) - - - -cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.sgd -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_gradient_release(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - gradient_release(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) - - -@pytest.mark.gpu -@pytest.mark.sgd -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_accumulation(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - if optim_name in ["sgd", "sgd_l2"]: - pytest.skip("Skip tests for SGD and SGD with L2 weight decay.") - # SGD will error out more often if iterations is the default of 80 - optimizer_accumulation(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - iterations=20, framework_opt_step=torch.rand(1).item() > 0.5) diff --git a/tests/stableadam_test.py b/tests/stableadam_test.py deleted file mode 100644 index c5ea260..0000000 --- a/tests/stableadam_test.py +++ /dev/null @@ -1,68 +0,0 @@ -from itertools import product - -import pytest -import torch - -import optimi -from tests import reference - -from tests.optimizer_test import (buffer, run_optimizer, gradient_release, cpu_dim1, cpu_dim2, cpu_dtype, - cpu_ftype, gpu_dim1, gpu_dim2, gpu_dtype, gpu_ftype, gr_dim1, - gr_dim2, gr_dtype, gr_ftype, optimizer_accumulation, gpu_device) - - - -optimizers = {} - 
-optimizers["stableadam"] = ({'optim':reference.StableAdamWUnfused, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=0, eps=1e-6)}, - {'optim':optimi.StableAdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=0, eps=1e-6)}) - -optimizers["stableadam_wd"] = ({'optim':reference.StableAdamWUnfused, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2, eps=1e-6)}, - {'optim':optimi.StableAdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2, eps=1e-6)}) - -optimizers["stableadam_dlr"] = ({'optim':reference.StableAdamWUnfused, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2, eps=1e-6)}, - {'optim':optimi.StableAdamW, 'kwargs':dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-5, eps=1e-6, decouple_lr=True)}) - -optimizer_names = [key for key in optimizers.keys()] - - - -cpu_values = list(product(cpu_dim1, cpu_dim2, cpu_dtype, optimizer_names, cpu_ftype)) -cpu_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cpu_values] - -@pytest.mark.cpu -@pytest.mark.stableadam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cpu_values, ids=cpu_names) -def test_optimizer_cpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device('cpu'), buffer) - - - -cuda_values = list(product(gpu_dim1, gpu_dim2, gpu_dtype, optimizer_names, gpu_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.stableadam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_gpu(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - run_optimizer(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), buffer, iterations=80) - - - -cuda_values = list(product(gr_dim1, gr_dim2, gr_dtype, optimizer_names, gr_ftype)) -cuda_names = ["dim1_{}_dim2_{}_dtype_{}_optim_{}{}".format(*vals) for vals in cuda_values] - -@pytest.mark.gpu -@pytest.mark.stableadam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_gradient_release(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - gradient_release(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) - - -@pytest.mark.gpu -@pytest.mark.stableadam -@pytest.mark.parametrize("dim1, dim2, dtype, optim_name, ftype", cuda_values, ids=cuda_names) -def test_optimizer_accumulation(dim1:int, dim2:int, dtype:torch.dtype, optim_name:str, ftype:str, gpu_device:str): - optimizer_accumulation(optimizers, dim1, dim2, dtype, optim_name, ftype, torch.device(gpu_device), - framework_opt_step=torch.rand(1).item() > 0.5) diff --git a/test_new/test_optimizers_unified.py b/tests/test_optimizers.py similarity index 100% rename from test_new/test_optimizers_unified.py rename to tests/test_optimizers.py From 6f2bf315da90f545e323e12275ac72c9951a0b82 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 26 Dec 2025 13:00:55 -0600 Subject: [PATCH 15/16] normalize test names --- tests/{to_low_precision_test.py => test_to_low_precision.py} | 0 tests/{weight_decay_test.py => test_weight_decay.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/{to_low_precision_test.py => test_to_low_precision.py} (100%) rename tests/{weight_decay_test.py => test_weight_decay.py} (100%) diff --git 
a/tests/to_low_precision_test.py b/tests/test_to_low_precision.py similarity index 100% rename from tests/to_low_precision_test.py rename to tests/test_to_low_precision.py diff --git a/tests/weight_decay_test.py b/tests/test_weight_decay.py similarity index 100% rename from tests/weight_decay_test.py rename to tests/test_weight_decay.py From f1a5cc1610d6ee652d4dce2382895356bc372fe5 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 26 Dec 2025 14:20:51 -0600 Subject: [PATCH 16/16] Fix gpu_device fixture --- tests/conftest.py | 2 +- tests/test_optimizers.py | 44 +++++++++++++++++++++++----------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index eda353d..f2a0008 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,7 +55,7 @@ def gpu_device(worker_id, request): backend = "mps" device_count = 0 else: - raise RuntimeError("No GPU backend available") + pytest.skip("No GPU backend available") if specific_gpu is not None: return torch.device(f"{backend}:{specific_gpu}") diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 66fb8cc..5b2e65f 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -24,13 +24,15 @@ OPTIMIZERS = [pytest.param(c, id=c.name, marks=getattr(pytest.mark, c.optimizer_name)) for c in discover_tests()] # Full dimensions: CPU -> (64,64), (64,128); GPU -> (256,256), (256,512), (256,1024), (256,2048) -FULL_DIMS = [ - pytest.param((DeviceType.cpu, (64, 64)), id="cpu-64x64"), - pytest.param((DeviceType.cpu, (64, 128)), id="cpu-64x128"), - pytest.param((DeviceType.gpu, (256, 256)), id="gpu-256x256"), - pytest.param((DeviceType.gpu, (256, 512)), id="gpu-256x512"), - pytest.param((DeviceType.gpu, (256, 1024)), id="gpu-256x1024"), - pytest.param((DeviceType.gpu, (256, 2048)), id="gpu-256x2048"), +CPU_DIMS = [ + pytest.param((64, 64), id="cpu-64x64"), + pytest.param((64, 128), id="cpu-64x128"), +] +GPU_FULL_DIMS = [ + pytest.param((256, 256), id="gpu-256x256"), + pytest.param((256, 512), id="gpu-256x512"), + pytest.param((256, 1024), id="gpu-256x1024"), + pytest.param((256, 2048), id="gpu-256x2048"), ] # Gradient release and accumulation dims: (128,256) and (128,1024) @@ -93,7 +95,6 @@ def _build_params(test_type: OptTestType) -> list[ParameterSet]: if test_type == OptTestType.normal: device_params = DEVICE_PARAMS dtype_params = DTYPE_PARAMS - dims_params = FULL_DIMS else: device_params = [pytest.param(DeviceType.gpu, marks=pytest.mark.gpu, id=DeviceType.gpu.value)] dtype_params = [pytest.param(torch.float32, marks=pytest.mark.float32, id="float32")] @@ -102,11 +103,11 @@ def _build_params(test_type: OptTestType) -> list[ParameterSet]: params: list[ParameterSet] = [] for opt_param in OPTIMIZERS: for device_param in device_params: + if test_type == OptTestType.normal: + dims_params = GPU_FULL_DIMS if _param_value(device_param) == DeviceType.gpu else CPU_DIMS for dtype_param in dtype_params: for backend_param in BACKEND_PARAMS: for dims_param in dims_params: - if test_type == OptTestType.normal and _param_value(dims_param)[0] != _param_value(device_param): - continue if _should_skip( test_type, _param_value(opt_param), @@ -138,18 +139,23 @@ def _build_params(test_type: OptTestType) -> list[ParameterSet]: return params -@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims_spec", _build_params(OptTestType.normal)) -def test_normal(opttest, device_type, dtype, backend, dims_spec, gpu_device): - _, dims = dims_spec - device = torch.device(gpu_device if device_type 
== DeviceType.gpu else "cpu") - run_test(opttest, device, dtype, backend, OptTestType.normal, dims=dims) +def _get_device(device_type: DeviceType, request: pytest.FixtureRequest) -> torch.device: + if device_type == DeviceType.gpu: + return torch.device(request.getfixturevalue("gpu_device")) + else: + return torch.device("cpu") + + +@pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(OptTestType.normal)) +def test_normal(opttest, device_type, dtype, backend, dims, request): + run_test(opttest, _get_device(device_type, request), dtype, backend, OptTestType.normal, dims=dims) @pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(OptTestType.gradient_release)) -def test_gradient_release(opttest, device_type, dtype, backend, dims, gpu_device): - run_test(opttest, torch.device(gpu_device), dtype, backend, OptTestType.gradient_release, dims=dims) +def test_gradient_release(opttest, device_type, dtype, backend, dims, request): + run_test(opttest, _get_device(device_type, request), dtype, backend, OptTestType.gradient_release, dims=dims) @pytest.mark.parametrize("opttest, device_type, dtype, backend, dims", _build_params(OptTestType.accumulation)) -def test_accumulation(opttest, device_type, dtype, backend, dims, gpu_device): - run_test(opttest, torch.device(gpu_device), dtype, backend, OptTestType.accumulation, dims=dims) +def test_accumulation(opttest, device_type, dtype, backend, dims, request): + run_test(opttest, _get_device(device_type, request), dtype, backend, OptTestType.accumulation, dims=dims)
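Note: the hunk above swaps the hard `raise RuntimeError` for `pytest.skip` and resolves the `gpu_device` fixture lazily through `request.getfixturevalue`, so CPU-only parametrizations still run on machines with no GPU backend. A minimal sketch of that pattern follows, assuming a simplified string-based device_type in place of the DeviceType enum; the real helpers (`_get_device`, `run_test`) live in tests/test_optimizers.py and are not reproduced here.

    import pytest
    import torch

    @pytest.fixture(scope="session")
    def gpu_device():
        # Skip the requesting test instead of erroring out when no GPU is present.
        if torch.cuda.is_available():
            return torch.device("cuda")
        pytest.skip("No GPU backend available")

    @pytest.mark.parametrize("device_type", ["cpu", "gpu"])
    def test_example(device_type, request):
        # The fixture is only resolved for GPU cases, so CPU cases never trigger the skip.
        if device_type == "gpu":
            device = torch.device(request.getfixturevalue("gpu_device"))
        else:
            device = torch.device("cpu")
        assert torch.zeros(1, device=device).device.type == device.type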