Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions aws/lambda/pytorch-auto-revert/WORKFLOW_DISPATCH_FILTERS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Workflow Dispatch Filters

## Overview

PyTorch CI workflows (`trunk.yml`, `pull.yml`) support optional filtering inputs for `workflow_dispatch` events. This allows autorevert to re-run only specific failed jobs and tests instead of the full CI suite.

## Workflow Dispatch Inputs

| Input | Type | Description |
|-------|------|-------------|
| `jobs-to-include` | string | Space-separated list of job display names to run (empty = all jobs) |
| `tests-to-include` | string | Space-separated list of test modules to run (empty = all tests) |

## Filter Value Derivation

Filter values are derived from Signal metadata during signal extraction.

### Job Names (`jobs-to-include`)

Derived from `Signal.job_base_name`. Job names follow two patterns:

| Pattern | Example | Filter Value |
|---------|---------|--------------|
| With ` / ` separator | `linux-jammy-cuda12.8-py3.10-gcc11 / test` | `linux-jammy-cuda12.8-py3.10-gcc11` |
| Without separator | `inductor-build` | `inductor-build` |

**More examples:**
- `linux-jammy-cuda12.8-py3.10-gcc11 / build` → `linux-jammy-cuda12.8-py3.10-gcc11`
- `linux-jammy-py3.10-gcc11` → `linux-jammy-py3.10-gcc11`
- `job-filter` → `job-filter`
- `get-label-type` → `get-label-type`

### Test Modules (`tests-to-include`)

Derived from `Signal.test_module`, which is set during signal extraction from the test file path (with the `.py` extension removed).

**Examples:**
- `test_torch`
- `test_nn`
- `distributed/elastic/multiprocessing/api_test`
- `distributed/test_c10d`

## Input Format Rules

### `jobs-to-include`
- Space-separated exact job **display names**
- Case-sensitive, must match exactly
- Examples:
- Build/test jobs: `"linux-jammy-cuda12.8-py3.10-gcc11 linux-jammy-py3.10-gcc11"`
- Standalone jobs: `"inductor-build job-filter get-label-type"`

### `tests-to-include`
- Space-separated test module paths (no `.py` extension)
- Module-level only (no `::TestClass::test_method`)
- Example: `"test_torch test_nn distributed/elastic/multiprocessing/api_test"`

## Behavior Notes

1. **Empty inputs** = run all jobs/tests (normal CI behavior)
2. **Filtered dispatch** = only matching jobs run; within those jobs, only matching tests run
3. **Test sharding** preserved - distributed tests still run on distributed shards
4. **TD (Target Determination) compatibility** - TD is disabled for filtered test runs; only the specified tests run
5. **Workflow support detection** - autorevert parses workflow YAML to check if inputs are supported before dispatch
45 changes: 43 additions & 2 deletions aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .github_client_helper import GHClientFactory
from .testers.autorevert_v2 import autorevert_v2
from .testers.hud import render_hud_html_from_clickhouse, write_hud_html_from_cli
from .testers.restart_checker import workflow_restart_checker
from .testers.restart_checker import dispatch_workflow_restart, workflow_restart_checker
from .utils import parse_datetime, RestartAction, RetryWithBackoff, RevertAction


Expand Down Expand Up @@ -356,6 +356,35 @@ def get_opts(default_config: DefaultConfig) -> argparse.Namespace:
help="If no `--commit` specified, look back days for bulk query (default: 7)",
)

# restart-workflow subcommand: dispatch a workflow restart with optional filters
restart_workflow_parser = subparsers.add_parser(
"restart-workflow",
help="Dispatch a workflow restart with optional job/test filters",
)
restart_workflow_parser.add_argument(
"workflow",
help="Workflow name (e.g., trunk or trunk.yml)",
)
restart_workflow_parser.add_argument(
"commit",
help="Commit SHA to restart",
)
restart_workflow_parser.add_argument(
"--jobs",
default=None,
help="Space-separated job display names to filter (e.g., 'linux-jammy-cuda12.8-py3.10-gcc11')",
)
restart_workflow_parser.add_argument(
"--tests",
default=None,
help="Space-separated test module paths to filter (e.g., 'test_torch distributed/test_c10d')",
)
restart_workflow_parser.add_argument(
"--repo-full-name",
default=default_config.repo_full_name,
help="Repository in owner/repo format (default: pytorch/pytorch)",
)

# hud subcommand: generate local HTML report for signals/detections
hud_parser = subparsers.add_parser(
"hud", help="Render HUD HTML from a logged autorevert run state"
Expand Down Expand Up @@ -437,10 +466,13 @@ def _get(attr: str, default=None):
log_level=_get("log_level", DEFAULT_LOG_LEVEL),
dry_run=_get("dry_run", False),
subcommand=_get("subcommand", "autorevert-checker"),
# Subcommand: workflow-restart-checker
# Subcommand: workflow-restart-checker and restart-workflow
workflow=_get("workflow", None),
commit=_get("commit", None),
days=_get("days", DEFAULT_WORKFLOW_RESTART_DAYS),
# Subcommand: restart-workflow (filter inputs)
jobs=_get("jobs", None),
tests=_get("tests", None),
# Subcommand: hud
timestamp=_get("timestamp", None),
hud_html=_get("hud_html", None),
Expand Down Expand Up @@ -693,6 +725,15 @@ def main_run(
workflow_restart_checker(
config.workflow, commit=config.commit, days=config.days
)
elif config.subcommand == "restart-workflow":
dispatch_workflow_restart(
workflow=config.workflow,
commit=config.commit,
jobs=config.jobs,
tests=config.tests,
repo=config.repo_full_name,
dry_run=config.dry_run,
)
elif config.subcommand == "hud":
out_path: Optional[str] = (
None if config.hud_html is HUD_HTML_NO_VALUE_FLAG else config.hud_html
Expand Down
8 changes: 7 additions & 1 deletion aws/lambda/pytorch-auto-revert/pytorch_auto_revert/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,18 @@ class AutorevertConfig:
subcommand: str = "autorevert-checker"

# -------------------------------------------------------------------------
# Subcommand: workflow-restart-checker
# Subcommand: workflow-restart-checker and restart-workflow
# -------------------------------------------------------------------------
workflow: Optional[str] = None
commit: Optional[str] = None
days: int = DEFAULT_WORKFLOW_RESTART_DAYS

# -------------------------------------------------------------------------
# Subcommand: restart-workflow (filter inputs)
# -------------------------------------------------------------------------
jobs: Optional[str] = None # Space-separated job display names
tests: Optional[str] = None # Space-separated test module paths

# -------------------------------------------------------------------------
# Subcommand: hud
# -------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ class Signal:
- workflow_name: source workflow this signal is derived from
- commits: newest → older list of SignalCommit objects for this signal
- job_base_name: optional job base name for job-level signals (recorded when signal is created)
- test_module: optional test module path for test-level signals (e.g., "test_torch" or "distributed/test_c10d")
"""

def __init__(
Expand All @@ -296,13 +297,16 @@ def __init__(
workflow_name: str,
commits: List[SignalCommit],
job_base_name: Optional[str] = None,
test_module: Optional[str] = None,
source: SignalSource = SignalSource.TEST,
):
self.key = key
self.workflow_name = workflow_name
# commits are ordered from newest to oldest
self.commits = commits
self.job_base_name = job_base_name
# Test module path without .py extension (e.g., "test_torch", "distributed/test_c10d")
self.test_module = test_module
# Track the origin of the signal (test-track or job-track).
self.source = source

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple, Union

import github

Expand Down Expand Up @@ -40,10 +40,33 @@ class SignalMetadata:
workflow_name: str
key: str
job_base_name: Optional[str] = None
test_module: Optional[str] = None
wf_run_id: Optional[int] = None
job_id: Optional[int] = None


def _derive_job_filter(job_base_name: Optional[str]) -> Optional[str]:
"""Extract job display name for jobs-to-include filter.

For jobs with " / " separator (e.g., "linux-jammy-cuda12.8 / test"),
returns the prefix before the separator.

For jobs without separator (e.g., "linux-jammy-py3.10-gcc11", "inductor-build"),
returns the full job_base_name as the display name.

Examples:
"linux-jammy-cuda12.8 / test" -> "linux-jammy-cuda12.8"
"linux-jammy-py3.10-gcc11" -> "linux-jammy-py3.10-gcc11"
"inductor-build" -> "inductor-build"
"job-filter" -> "job-filter"
"""
if not job_base_name:
return None
if " / " in job_base_name:
return job_base_name.split(" / ")[0].strip()
return job_base_name.strip()


@dataclass(frozen=True)
class ActionGroup:
"""A coalesced action candidate built from one or more signals.
Expand All @@ -52,12 +75,16 @@ class ActionGroup:
- commit_sha: target commit
- workflow_target: workflow to restart (restart only); None/'' for revert
- sources: contributing signals (workflow_name, key, outcome)
- jobs_to_include: job display names to filter for restart (empty = all jobs)
- tests_to_include: test module paths to filter for restart (empty = all tests)
"""

type: str # 'revert' | 'restart'
commit_sha: str
workflow_target: str | None # restart-only; None/'' for revert
sources: List[SignalMetadata]
jobs_to_include: FrozenSet[str] = frozenset()
tests_to_include: FrozenSet[str] = frozenset()


class ActionLogger:
Expand Down Expand Up @@ -229,6 +256,7 @@ def group_actions(
workflow_name=sig.workflow_name,
key=sig.key,
job_base_name=sig.job_base_name,
test_module=sig.test_module,
wf_run_id=wf_run_id,
job_id=job_id,
)
Expand All @@ -251,12 +279,18 @@ def group_actions(
)
)
for (wf, sha), sources in restart_map.items():
jobs = [_derive_job_filter(src.job_base_name) for src in sources]

groups.append(
ActionGroup(
type="restart",
commit_sha=sha,
workflow_target=wf,
sources=sources,
jobs_to_include=frozenset(j for j in jobs if j is not None),
tests_to_include=frozenset(
src.test_module for src in sources if src.test_module
),
)
)
return groups
Expand All @@ -279,6 +313,8 @@ def execute(self, group: ActionGroup, ctx: RunContext) -> bool:
commit_sha=group.commit_sha,
sources=group.sources,
ctx=ctx,
jobs_to_include=group.jobs_to_include,
tests_to_include=group.tests_to_include,
)
return False

Expand Down Expand Up @@ -330,6 +366,8 @@ def execute_restart(
commit_sha: str,
sources: List[SignalMetadata],
ctx: RunContext,
jobs_to_include: FrozenSet[str] = frozenset(),
tests_to_include: FrozenSet[str] = frozenset(),
) -> bool:
"""Dispatch a workflow restart subject to pacing, cap, and backoff; always logs the event."""
if ctx.restart_action == RestartAction.SKIP:
Expand Down Expand Up @@ -374,18 +412,32 @@ def execute_restart(
)
return False

notes = ""
# Build notes incrementally
notes_parts: list[str] = []
if jobs_to_include:
notes_parts.append(f"jobs_filter={','.join(jobs_to_include)}")
if tests_to_include:
notes_parts.append(f"tests_filter={','.join(tests_to_include)}")

ok = True
if not dry_run:
try:
self._restart.restart_workflow(workflow_target, commit_sha)
self._restart.restart_workflow(
workflow_target,
commit_sha,
jobs_to_include=jobs_to_include,
tests_to_include=tests_to_include,
)
except Exception as exc:
ok = False
notes = str(exc) or repr(exc)
notes_parts.append(str(exc) or repr(exc))
logging.exception(
"[v2][action] restart for sha %s: exception while dispatching",
commit_sha[:8],
)

notes = "; ".join(notes_parts)

self._logger.insert_event(
repo=ctx.repo_full_name,
ts=ctx.ts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]:
workflow_name=s.workflow_name,
commits=new_commits,
job_base_name=s.job_base_name,
test_module=s.test_module,
source=s.source,
)
)
Expand Down Expand Up @@ -218,6 +219,7 @@ def _inject_pending_workflow_events(
workflow_name=s.workflow_name,
commits=new_commits,
job_base_name=s.job_base_name,
test_module=s.test_module,
source=s.source,
)
)
Expand Down Expand Up @@ -434,12 +436,17 @@ def _build_test_signals(
)

if has_any_events:
# Extract test module from test_id (format: "file.py::test_name")
# Result: "file" or "path/to/file" without .py extension
test_module = test_id.split("::")[0].replace(".py", "")

signals.append(
Signal(
key=test_id,
workflow_name=wf_name,
commits=commit_objs,
job_base_name=str(job_base_name),
test_module=test_module,
source=SignalSource.TEST,
)
)
Expand Down
Loading