Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Empty file.
51 changes: 51 additions & 0 deletions mlpstorage/submission_checker/checks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod


class BaseCheck(ABC):
"""
A generic check class meant to be inherited by concrete check implementations.
Subclasses must register their check methods into `self.checks`.
"""

def __init__(self, log, path):
self.checks = []
self.log = log
self.path = path
self.name = "base checks"
pass

def run_checks(self):
"""
Execute all registered checks. Returns True if all checks pass, False otherwise.
"""
valid = True
errors = []
for check in self.checks:
try:
v = self.execute(check)
valid &= v
except BaseException:
valid &= False
self.log.error(
"Exception occurred in %s while running %s in %s",
self.path,
check.__name__,
self.__class__.__name__)
return valid

def execute(self, check):
"""Custom execution of a single check method."""
return check()

def __call__(self):
"""Allows the check instance to be called like a function."""
self.log.info("Starting %s for: %s", self.name, self.path)
valid = self.run_checks()
if valid:
self.log.info("All %s checks passed for: %s", self.name, self.path)
else:
self.log.error(
"Some %s Checks failed for: %s",
self.name,
self.path)
return valid
282 changes: 282 additions & 0 deletions mlpstorage/submission_checker/checks/checkpointing_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@

from .base import BaseCheck
from ..constants import *
from ..configuration.configuration import Config
from ..loader import SubmissionLogs

import os
import re


class CheckpointingCheck(BaseCheck):
"""
A check class for validating checkpointing parameters and related properties.
Inherits from BaseCheck and receives a config and loader instance.
"""

def __init__(self, log, config: Config, submissions_logs: SubmissionLogs):
"""
Initialize CheckpointingChecks with configuration and loader.

Args:
config: A Config instance containing submission configuration.
loader: A SubmissionLogs instance for accessing submission logs.
"""
# Call parent constructor with the loader's log and submission path
super().__init__(log=log, path=submissions_logs.loader_metadata.folder)
self.config = config
self.submissions_logs = submissions_logs.checkpoint_files
self.name = "checkpointing checks"
self.checks = []
self.checkpointing_path = self.path
self.init_checks()

def init_checks(self):
"""Initialize the list of checks to run."""
self.checks = [
self.checkpoint_data_size_ratio,
self.fsync_verification,
self.model_configuration_req,
self.closed_mpi_processes,
self.closed_accelerators_per_host,
self.aggregate_accelerator_memory,
self.closed_checkpoint_parameters,
self.checkpoint_path_args,
self.subset_run_validation,
]

def checkpoint_data_size_ratio(self):
"""
Verify that checkpoint data written per node > 3x node memory.
"""
valid = True

for summary, metadata, _ in self.submissions_logs:
checkpoint_size_gb = summary.get("metric", {}).get("checkpoint_size_GB", 0)
host_memory_gb = summary.get("host_memory_GB", [0])[0]
num_hosts = summary.get("num_hosts", 1)

if checkpoint_size_gb == 0 or host_memory_gb == 0:
continue

# Data written per node
data_per_node = checkpoint_size_gb / num_hosts
min_required = 3 * host_memory_gb

if data_per_node < min_required:
self.log.warning(
"Checkpoint data per node %.2f GB < 3x memory %.2f GB. "
"Cache flush may be needed.",
data_per_node,
min_required
)

return valid

def fsync_verification(self):
"""
Verify that fsync is enabled in checkpoint configuration.
"""
valid = True

for summary, metadata, _ in self.submissions_logs:
combined_params = metadata.get("combined_params", {})
checkpoint_params = combined_params.get("checkpoint", {})
fsync_enabled = checkpoint_params.get("fsync", False)

if not fsync_enabled:
self.log.error("Checkpoint fsync is not enabled in configuration")
valid = False

return valid

def model_configuration_req(self):
"""
Verify benchmark uses one of the four supported models.
"""
valid = True
allowed_models = {"8b", "70b", "405b", "1t"}

for summary, metadata, _ in self.submissions_logs:
model_name = metadata.get("args", {}).get("model", "").lower()

# Extract just the size part (8b, 70b, etc.)
model_size = re.search(r"(8b|70b|405b|1t)", model_name)

if not model_size or model_size.group(1) not in allowed_models:
self.log.error(
"Invalid model '%s'. Must be one of: %s",
model_name,
allowed_models
)
valid = False

return valid

def closed_mpi_processes(self):
"""
For CLOSED submissions, verify MPI processes match requirements per model.
"""
valid = True

model_process_requirements = {
"8b": 8,
"70b": 64,
"405b": 512,
"1t": 1024
}

for summary, metadata, _ in self.submissions_logs:
verification = metadata.get("verification", "open")

if verification == "closed":
model_name = metadata.get("args", {}).get("model", "").lower()
num_processes = metadata.get("args", {}).get("num_processes", 0)

model_size = re.search(r"(8b|70b|405b|1t)", model_name)
if model_size:
model_key = model_size.group(1)
required_processes = model_process_requirements.get(model_key)

if required_processes and num_processes != required_processes:
self.log.error(
"CLOSED submission with model %s requires %d processes, got %d",
model_key,
required_processes,
num_processes
)
valid = False

return valid

def closed_accelerators_per_host(self):
"""
For CLOSED submissions, verify accelerators per host > 4 and total matches requirement.
"""
valid = True

for summary, metadata, _ in self.submissions_logs:
verification = metadata.get("verification", "open")

if verification == "closed":
num_accelerators = summary.get("num_accelerators", 0)
num_hosts = summary.get("num_hosts", 1)

accelerators_per_host = num_accelerators / num_hosts if num_hosts > 0 else 0

if accelerators_per_host <= 4:
self.log.error(
"CLOSED submission: accelerators per host %.2f must be > 4",
accelerators_per_host
)
valid = False

return valid

def aggregate_accelerator_memory(self):
"""
Verify total accelerator memory >= checkpoint size.
H100 has 80GB per accelerator.
"""
valid = True
ACCELERATOR_MEMORY_GB = 80 # H100

for summary, metadata, _ in self.submissions_logs:
checkpoint_size_gb = summary.get("metric", {}).get("checkpoint_size_GB", 0)
num_accelerators = summary.get("num_accelerators", 0)

total_accelerator_memory = num_accelerators * ACCELERATOR_MEMORY_GB

if total_accelerator_memory < checkpoint_size_gb:
self.log.error(
"Aggregate accelerator memory %.2f GB < checkpoint size %.2f GB",
total_accelerator_memory,
checkpoint_size_gb
)
valid = False

return valid

def closed_checkpoint_parameters(self):
"""
For CLOSED submissions, verify only allowed parameters are modified.
"""
valid = True

allowed_params = {
"checkpoint.checkpoint_folder"
}

for summary, metadata, _ in self.submissions_logs:
verification = metadata.get("verification", "open")

if verification == "closed":
params_dict = metadata.get("params_dict", {})

for param_key in params_dict.keys():
if param_key not in allowed_params:
self.log.error(
"CLOSED submission modifies disallowed parameter: %s",
param_key
)
valid = False

return valid

def checkpoint_path_args(self):
"""
Verify checkpoint folder and output paths are set and different.
"""
valid = True

for summary, metadata, _ in self.submissions_logs:
args = metadata.get("args", {})
checkpoint_folder = args.get("checkpoint_folder")
results_dir = args.get("results_dir")

if not checkpoint_folder:
self.log.error("checkpoint_folder not set in arguments")
valid = False

if not results_dir:
self.log.error("results_dir not set in arguments")
valid = False

if checkpoint_folder and results_dir and checkpoint_folder == results_dir:
self.log.error(
"checkpoint_folder and results_dir must be different: both are %s",
checkpoint_folder
)
valid = False

return valid

def subset_run_validation(self):
"""
For subset runs, verify exactly 8 accelerators and not 8B model.
"""
valid = True

for summary, metadata, _ in self.submissions_logs:
params_dict = metadata.get("params_dict", {})
checkpoint_mode = params_dict.get("checkpoint.mode", "")

if checkpoint_mode == "subset":
num_accelerators = summary.get("num_accelerators", 0)
model_name = metadata.get("args", {}).get("model", "").lower()

if num_accelerators != 8:
self.log.error(
"Subset run requires exactly 8 accelerators, got %d",
num_accelerators
)
valid = False

if "8b" in model_name:
self.log.error(
"Subset run cannot use 8B model: %s",
model_name
)
valid = False

return valid
Loading