diff --git a/aws/lambda/pytorch-auto-revert/WORKFLOW_DISPATCH_FILTERS.md b/aws/lambda/pytorch-auto-revert/WORKFLOW_DISPATCH_FILTERS.md new file mode 100644 index 0000000000..d59f31f8f6 --- /dev/null +++ b/aws/lambda/pytorch-auto-revert/WORKFLOW_DISPATCH_FILTERS.md @@ -0,0 +1,63 @@ +# Workflow Dispatch Filters + +## Overview + +PyTorch CI workflows (`trunk.yml`, `pull.yml`) support optional filtering inputs for `workflow_dispatch` events. This allows autorevert to re-run only specific failed jobs and tests instead of the full CI suite. + +## Workflow Dispatch Inputs + +| Input | Type | Description | +|-------|------|-------------| +| `jobs-to-include` | string | Space-separated list of job display names to run (empty = all jobs) | +| `tests-to-include` | string | Space-separated list of test modules to run (empty = all tests) | + +## Filter Value Derivation + +Filter values are derived from Signal metadata during signal extraction. + +### Job Names (`jobs-to-include`) + +Derived from `Signal.job_base_name`. Job names follow two patterns: + +| Pattern | Example | Filter Value | +|---------|---------|--------------| +| With ` / ` separator | `linux-jammy-cuda12.8-py3.10-gcc11 / test` | `linux-jammy-cuda12.8-py3.10-gcc11` | +| Without separator | `inductor-build` | `inductor-build` | + +**More examples:** +- `linux-jammy-cuda12.8-py3.10-gcc11 / build` → `linux-jammy-cuda12.8-py3.10-gcc11` +- `linux-jammy-py3.10-gcc11` → `linux-jammy-py3.10-gcc11` +- `job-filter` → `job-filter` +- `get-label-type` → `get-label-type` + +### Test Modules (`tests-to-include`) + +Derived from `Signal.test_module` (set during signal extraction from test file path, without `.py` extension). 
+ +**Examples:** +- `test_torch` +- `test_nn` +- `distributed/elastic/multiprocessing/api_test` +- `distributed/test_c10d` + +## Input Format Rules + +### `jobs-to-include` +- Space-separated exact job **display names** +- Case-sensitive, must match exactly +- Examples: + - Build/test jobs: `"linux-jammy-cuda12.8-py3.10-gcc11 linux-jammy-py3.10-gcc11"` + - Standalone jobs: `"inductor-build job-filter get-label-type"` + +### `tests-to-include` +- Space-separated test module paths (no `.py` extension) +- Module-level only (no `::TestClass::test_method`) +- Example: `"test_torch test_nn distributed/elastic/multiprocessing/api_test"` + +## Behavior Notes + +1. **Empty inputs** = run all jobs/tests (normal CI behavior) +2. **Filtered dispatch** = only matching jobs run; within those jobs, only matching tests run +3. **Test sharding** preserved - distributed tests still run on distributed shards +4. **TD compatibility** - TD is disabled for filtered test runs; only specified tests run +5. **Workflow support detection** - autorevert parses workflow YAML to check if inputs are supported before dispatch diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py index 9eb14b2df2..a9dcf0a9d3 100755 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py @@ -29,7 +29,7 @@ from .github_client_helper import GHClientFactory from .testers.autorevert_v2 import autorevert_v2 from .testers.hud import render_hud_html_from_clickhouse, write_hud_html_from_cli -from .testers.restart_checker import workflow_restart_checker +from .testers.restart_checker import dispatch_workflow_restart, workflow_restart_checker from .utils import parse_datetime, RestartAction, RetryWithBackoff, RevertAction @@ -356,6 +356,35 @@ def get_opts(default_config: DefaultConfig) -> argparse.Namespace: help="If no `--commit` specified, look back days 
for bulk query (default: 7)", ) + # restart-workflow subcommand: dispatch a workflow restart with optional filters + restart_workflow_parser = subparsers.add_parser( + "restart-workflow", + help="Dispatch a workflow restart with optional job/test filters", + ) + restart_workflow_parser.add_argument( + "workflow", + help="Workflow name (e.g., trunk or trunk.yml)", + ) + restart_workflow_parser.add_argument( + "commit", + help="Commit SHA to restart", + ) + restart_workflow_parser.add_argument( + "--jobs", + default=None, + help="Space-separated job display names to filter (e.g., 'linux-jammy-cuda12.8-py3.10-gcc11')", + ) + restart_workflow_parser.add_argument( + "--tests", + default=None, + help="Space-separated test module paths to filter (e.g., 'test_torch distributed/test_c10d')", + ) + restart_workflow_parser.add_argument( + "--repo-full-name", + default=default_config.repo_full_name, + help="Repository in owner/repo format (default: pytorch/pytorch)", + ) + # hud subcommand: generate local HTML report for signals/detections hud_parser = subparsers.add_parser( "hud", help="Render HUD HTML from a logged autorevert run state" @@ -437,10 +466,13 @@ def _get(attr: str, default=None): log_level=_get("log_level", DEFAULT_LOG_LEVEL), dry_run=_get("dry_run", False), subcommand=_get("subcommand", "autorevert-checker"), - # Subcommand: workflow-restart-checker + # Subcommand: workflow-restart-checker and restart-workflow workflow=_get("workflow", None), commit=_get("commit", None), days=_get("days", DEFAULT_WORKFLOW_RESTART_DAYS), + # Subcommand: restart-workflow (filter inputs) + jobs=_get("jobs", None), + tests=_get("tests", None), # Subcommand: hud timestamp=_get("timestamp", None), hud_html=_get("hud_html", None), @@ -693,6 +725,15 @@ def main_run( workflow_restart_checker( config.workflow, commit=config.commit, days=config.days ) + elif config.subcommand == "restart-workflow": + dispatch_workflow_restart( + workflow=config.workflow, + commit=config.commit, + 
jobs=config.jobs, + tests=config.tests, + repo=config.repo_full_name, + dry_run=config.dry_run, + ) elif config.subcommand == "hud": out_path: Optional[str] = ( None if config.hud_html is HUD_HTML_NO_VALUE_FLAG else config.hud_html diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/config.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/config.py index 20bf30b27f..af918f4af7 100644 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/config.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/config.py @@ -82,12 +82,18 @@ class AutorevertConfig: subcommand: str = "autorevert-checker" # ------------------------------------------------------------------------- - # Subcommand: workflow-restart-checker + # Subcommand: workflow-restart-checker and restart-workflow # ------------------------------------------------------------------------- workflow: Optional[str] = None commit: Optional[str] = None days: int = DEFAULT_WORKFLOW_RESTART_DAYS + # ------------------------------------------------------------------------- + # Subcommand: restart-workflow (filter inputs) + # ------------------------------------------------------------------------- + jobs: Optional[str] = None # Space-separated job display names + tests: Optional[str] = None # Space-separated test module paths + # ------------------------------------------------------------------------- # Subcommand: hud # ------------------------------------------------------------------------- diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py index cd67868680..092be7de8e 100644 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py @@ -288,6 +288,7 @@ class Signal: - workflow_name: source workflow this signal is derived from - commits: newest → older list of SignalCommit objects for this signal - job_base_name: optional job base name 
for job-level signals (recorded when signal is created) + - test_module: optional test module path for test-level signals (e.g., "test_torch" or "distributed/test_c10d") """ def __init__( @@ -296,6 +297,7 @@ def __init__( workflow_name: str, commits: List[SignalCommit], job_base_name: Optional[str] = None, + test_module: Optional[str] = None, source: SignalSource = SignalSource.TEST, ): self.key = key @@ -303,6 +305,8 @@ def __init__( # commits are ordered from newest to oldest self.commits = commits self.job_base_name = job_base_name + # Test module path without .py extension (e.g., "test_torch", "distributed/test_c10d") + self.test_module = test_module # Track the origin of the signal (test-track or job-track). self.source = source diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_actions.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_actions.py index a191aad9a7..62368bfa8d 100644 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_actions.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_actions.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from enum import Enum -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple, Union import github @@ -40,10 +40,33 @@ class SignalMetadata: workflow_name: str key: str job_base_name: Optional[str] = None + test_module: Optional[str] = None wf_run_id: Optional[int] = None job_id: Optional[int] = None +def _derive_job_filter(job_base_name: Optional[str]) -> Optional[str]: + """Extract job display name for jobs-to-include filter. + + For jobs with " / " separator (e.g., "linux-jammy-cuda12.8 / test"), + returns the prefix before the separator. + + For jobs without separator (e.g., "linux-jammy-py3.10-gcc11", "inductor-build"), + returns the full job_base_name as the display name. 
+ + Examples: + "linux-jammy-cuda12.8 / test" -> "linux-jammy-cuda12.8" + "linux-jammy-py3.10-gcc11" -> "linux-jammy-py3.10-gcc11" + "inductor-build" -> "inductor-build" + "job-filter" -> "job-filter" + """ + if not job_base_name: + return None + if " / " in job_base_name: + return job_base_name.split(" / ")[0].strip() + return job_base_name.strip() + + @dataclass(frozen=True) class ActionGroup: """A coalesced action candidate built from one or more signals. @@ -52,12 +75,16 @@ class ActionGroup: - commit_sha: target commit - workflow_target: workflow to restart (restart only); None/'' for revert - sources: contributing signals (workflow_name, key, outcome) + - jobs_to_include: job display names to filter for restart (empty = all jobs) + - tests_to_include: test module paths to filter for restart (empty = all tests) """ type: str # 'revert' | 'restart' commit_sha: str workflow_target: str | None # restart-only; None/'' for revert sources: List[SignalMetadata] + jobs_to_include: FrozenSet[str] = frozenset() + tests_to_include: FrozenSet[str] = frozenset() class ActionLogger: @@ -229,6 +256,7 @@ def group_actions( workflow_name=sig.workflow_name, key=sig.key, job_base_name=sig.job_base_name, + test_module=sig.test_module, wf_run_id=wf_run_id, job_id=job_id, ) @@ -251,12 +279,18 @@ def group_actions( ) ) for (wf, sha), sources in restart_map.items(): + jobs = [_derive_job_filter(src.job_base_name) for src in sources] + groups.append( ActionGroup( type="restart", commit_sha=sha, workflow_target=wf, sources=sources, + jobs_to_include=frozenset(j for j in jobs if j is not None), + tests_to_include=frozenset( + src.test_module for src in sources if src.test_module + ), ) ) return groups @@ -279,6 +313,8 @@ def execute(self, group: ActionGroup, ctx: RunContext) -> bool: commit_sha=group.commit_sha, sources=group.sources, ctx=ctx, + jobs_to_include=group.jobs_to_include, + tests_to_include=group.tests_to_include, ) return False @@ -330,6 +366,8 @@ def execute_restart( 
commit_sha: str, sources: List[SignalMetadata], ctx: RunContext, + jobs_to_include: FrozenSet[str] = frozenset(), + tests_to_include: FrozenSet[str] = frozenset(), ) -> bool: """Dispatch a workflow restart subject to pacing, cap, and backoff; always logs the event.""" if ctx.restart_action == RestartAction.SKIP: @@ -374,18 +412,32 @@ def execute_restart( ) return False - notes = "" + # Build notes incrementally + notes_parts: list[str] = [] + if jobs_to_include: + notes_parts.append(f"jobs_filter={','.join(jobs_to_include)}") + if tests_to_include: + notes_parts.append(f"tests_filter={','.join(tests_to_include)}") + ok = True if not dry_run: try: - self._restart.restart_workflow(workflow_target, commit_sha) + self._restart.restart_workflow( + workflow_target, + commit_sha, + jobs_to_include=jobs_to_include, + tests_to_include=tests_to_include, + ) except Exception as exc: ok = False - notes = str(exc) or repr(exc) + notes_parts.append(str(exc) or repr(exc)) logging.exception( "[v2][action] restart for sha %s: exception while dispatching", commit_sha[:8], ) + + notes = "; ".join(notes_parts) + self._logger.insert_event( repo=ctx.repo_full_name, ts=ctx.ts, diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py index 1358d852dd..76fd291bb3 100644 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py @@ -133,6 +133,7 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]: workflow_name=s.workflow_name, commits=new_commits, job_base_name=s.job_base_name, + test_module=s.test_module, source=s.source, ) ) @@ -218,6 +219,7 @@ def _inject_pending_workflow_events( workflow_name=s.workflow_name, commits=new_commits, job_base_name=s.job_base_name, + test_module=s.test_module, source=s.source, ) ) @@ -434,12 +436,17 @@ def _build_test_signals( ) if 
has_any_events: + # Extract test module from test_id (format: "file.py::test_name") + # Result: "file" or "path/to/file" without .py extension + test_module = test_id.split("::")[0].removesuffix(".py") + signals.append( Signal( key=test_id, workflow_name=wf_name, commits=commit_objs, job_base_name=str(job_base_name), + test_module=test_module, source=SignalSource.TEST, ) ) diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/restart_checker.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/restart_checker.py index e7e8e5addf..f75b5f7495 100644 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/restart_checker.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/restart_checker.py @@ -1,4 +1,7 @@ +from typing import FrozenSet, Optional + from ..workflow_checker import WorkflowRestartChecker +from ..workflow_resolver import WorkflowResolver def workflow_restart_checker(workflow: str, commit: str = None, days: int = 7) -> None: @@ -15,3 +18,91 @@ def workflow_restart_checker(workflow: str, commit: str = None, days: int = 7) - print(f" ✓ {commit}") else: print(" No restarted workflows found") + + +def dispatch_workflow_restart( + workflow: str, + commit: str, + jobs: Optional[str] = None, + tests: Optional[str] = None, + repo: str = "pytorch/pytorch", + dry_run: bool = False, +) -> None: + """Dispatch a workflow restart with optional job/test filters. 
+ + Args: + workflow: Workflow name (e.g., "trunk" or "trunk.yml") + commit: Commit SHA to restart + jobs: Space-separated job display names to filter (or None for all) + tests: Space-separated test module paths to filter (or None for all) + repo: Repository in owner/repo format + dry_run: If True, only show what would be dispatched + """ + # Parse filter strings to frozensets + jobs_to_include: FrozenSet[str] = frozenset(jobs.split()) if jobs else frozenset() + tests_to_include: FrozenSet[str] = ( + frozenset(tests.split()) if tests else frozenset() + ) + + # Get workflow resolver and check input support + resolver = WorkflowResolver.get(repo) + wf_ref = resolver.require(workflow) + input_support = resolver.get_input_support(workflow) + + print(f"Workflow: {wf_ref.display_name} ({wf_ref.file_name})") + print(f"Commit: {commit}") + print(f"Repository: {repo}") + print() + + # Show input support status + print("Workflow input support:") + print( + f" jobs-to-include: {'✓ supported' if input_support.jobs_to_include else '✗ not supported'}" + ) + print( + f" tests-to-include: {'✓ supported' if input_support.tests_to_include else '✗ not supported'}" + ) + print() + + # Show what filters will be applied + effective_jobs = jobs_to_include if input_support.jobs_to_include else frozenset() + effective_tests = ( + tests_to_include if input_support.tests_to_include else frozenset() + ) + + if jobs_to_include: + if input_support.jobs_to_include: + print(f"Jobs filter: {' '.join(sorted(jobs_to_include))}") + else: + print( + f"Jobs filter: {' '.join(sorted(jobs_to_include))} (IGNORED - workflow doesn't support)" + ) + + if tests_to_include: + if input_support.tests_to_include: + print(f"Tests filter: {' '.join(sorted(tests_to_include))}") + else: + print( + f"Tests filter: {' '.join(sorted(tests_to_include))} (IGNORED - workflow doesn't support)" + ) + + if not jobs_to_include and not tests_to_include: + print("Filters: none (full CI run)") + + print() + + if dry_run: + 
print("DRY RUN - would dispatch workflow with above settings") + return + + # Dispatch the workflow + checker = WorkflowRestartChecker( + repo_owner=repo.split("/")[0], repo_name=repo.split("/")[1] + ) + checker.restart_workflow( + workflow, + commit, + jobs_to_include=effective_jobs, + tests_to_include=effective_tests, + ) + print("✓ Workflow dispatched successfully") diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_workflow_resolver.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_workflow_resolver.py index aab26d18c5..fb01869e01 100644 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_workflow_resolver.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_workflow_resolver.py @@ -1,6 +1,9 @@ import os import sys import unittest +from unittest.mock import MagicMock + +import github # Ensure package import when running from repo root @@ -56,5 +59,127 @@ def test_resolve_trunk_workflow(self): ) +class TestParseWorkflowInputs(unittest.TestCase): + """Unit tests for _parse_workflow_inputs without GitHub API calls.""" + + def _create_resolver(self) -> WorkflowResolver: + """Create a resolver with mocked repository for testing.""" + mock_repo = MagicMock(spec=github.Repository.Repository) + mock_repo.get_workflows.return_value = [] + return WorkflowResolver(repo_full_name="test/repo", repository=mock_repo) + + def test_parse_workflow_with_both_inputs(self): + resolver = self._create_resolver() + yaml_content = """ +name: trunk +on: + workflow_dispatch: + inputs: + jobs-to-include: + description: 'Space-separated job names' + required: false + type: string + tests-to-include: + description: 'Space-separated test modules' + required: false + type: string +""" + result = resolver._parse_workflow_inputs(yaml_content) + self.assertTrue(result.jobs_to_include) + self.assertTrue(result.tests_to_include) + self.assertTrue(result.supports_filtering) + + def test_parse_workflow_with_jobs_only(self): + 
resolver = self._create_resolver() + yaml_content = """ +name: trunk +on: + workflow_dispatch: + inputs: + jobs-to-include: + type: string +""" + result = resolver._parse_workflow_inputs(yaml_content) + self.assertTrue(result.jobs_to_include) + self.assertFalse(result.tests_to_include) + + def test_parse_workflow_with_tests_only(self): + resolver = self._create_resolver() + yaml_content = """ +name: trunk +on: + workflow_dispatch: + inputs: + tests-to-include: + type: string +""" + result = resolver._parse_workflow_inputs(yaml_content) + self.assertFalse(result.jobs_to_include) + self.assertTrue(result.tests_to_include) + + def test_parse_workflow_no_inputs(self): + resolver = self._create_resolver() + yaml_content = """ +name: trunk +on: + workflow_dispatch: +""" + result = resolver._parse_workflow_inputs(yaml_content) + self.assertFalse(result.jobs_to_include) + self.assertFalse(result.tests_to_include) + self.assertFalse(result.supports_filtering) + + def test_parse_workflow_no_workflow_dispatch(self): + resolver = self._create_resolver() + yaml_content = """ +name: ci +on: + push: + branches: [main] + pull_request: +""" + result = resolver._parse_workflow_inputs(yaml_content) + self.assertFalse(result.jobs_to_include) + self.assertFalse(result.tests_to_include) + + def test_parse_workflow_simple_on_trigger(self): + resolver = self._create_resolver() + yaml_content = """ +name: simple +on: push +""" + result = resolver._parse_workflow_inputs(yaml_content) + self.assertFalse(result.jobs_to_include) + self.assertFalse(result.tests_to_include) + + def test_parse_workflow_empty_yaml(self): + resolver = self._create_resolver() + result = resolver._parse_workflow_inputs("") + self.assertFalse(result.jobs_to_include) + self.assertFalse(result.tests_to_include) + + def test_parse_workflow_invalid_yaml(self): + resolver = self._create_resolver() + # Invalid YAML is caught and returns empty support (doesn't raise) + result = resolver._parse_workflow_inputs("{{invalid 
yaml::") + self.assertFalse(result.jobs_to_include) + self.assertFalse(result.tests_to_include) + + def test_parse_workflow_on_as_true_yaml11(self): + """YAML 1.1 parses 'on' as boolean True in some cases.""" + resolver = self._create_resolver() + # This simulates what happens when YAML parser treats 'on' as True + yaml_content = """ +name: trunk +true: + workflow_dispatch: + inputs: + jobs-to-include: + type: string +""" + result = resolver._parse_workflow_inputs(yaml_content) + self.assertTrue(result.jobs_to_include) + + if __name__ == "__main__": unittest.main() diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_checker.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_checker.py index 64622cb009..50379949d0 100644 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_checker.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_checker.py @@ -5,11 +5,11 @@ import logging from datetime import datetime, timedelta -from typing import Dict, Set +from typing import Dict, FrozenSet, Set from .clickhouse_client_helper import CHCliFactory from .utils import proper_workflow_create_dispatch, RetryWithBackoff -from .workflow_resolver import WorkflowResolver +from .workflow_resolver import WorkflowInputSupport, WorkflowResolver class WorkflowRestartChecker: @@ -110,13 +110,21 @@ def clear_cache(self): """Clear the results cache.""" self._cache.clear() - def restart_workflow(self, workflow_name: str, commit_sha: str) -> None: + def restart_workflow( + self, + workflow_name: str, + commit_sha: str, + jobs_to_include: FrozenSet[str] = frozenset(), + tests_to_include: FrozenSet[str] = frozenset(), + ) -> None: """ - Restart a workflow for a specific commit SHA. + Restart a workflow for a specific commit SHA with optional filtering. 
Args: workflow_name: Name of the workflow (e.g., "trunk" or "trunk.yml") commit_sha: The commit SHA to restart workflow for + jobs_to_include: Job display names to filter (empty = all jobs) + tests_to_include: Test module paths to filter (empty = all tests) Raises: RuntimeError: If GitHub authentication is not configured. @@ -138,20 +146,54 @@ def restart_workflow(self, workflow_name: str, commit_sha: str) -> None: # Resolve workflow (exact display or file name) wf_ref = self.resolver.require(workflow_name) + # Check what inputs this workflow supports (fail gracefully) + try: + input_support = self.resolver.get_input_support(workflow_name) + except Exception: + logging.warning( + "Failed to check input support for %s, proceeding without filters", + workflow_name, + exc_info=True, + ) + input_support = WorkflowInputSupport() + + # Build inputs dict based on support and available filters + # (sorted for a deterministic dispatch payload and stable logs) + inputs: Dict[str, str] = {} + if input_support.jobs_to_include and jobs_to_include: + inputs["jobs-to-include"] = " ".join(sorted(jobs_to_include)) + if input_support.tests_to_include and tests_to_include: + inputs["tests-to-include"] = " ".join(sorted(tests_to_include)) + + # Separate retry scopes: don't retry get_repo/get_workflow on dispatch failure for attempt in RetryWithBackoff(): with attempt: repo = client.get_repo(f"{self.repo_owner}/{self.repo_name}") workflow = repo.get_workflow(wf_ref.file_name) - proper_workflow_create_dispatch(workflow, ref=tag_ref, inputs={}) + + for attempt in RetryWithBackoff(): + with attempt: + proper_workflow_create_dispatch(workflow, ref=tag_ref, inputs=inputs) workflow_url = ( f"https://github.com/{self.repo_owner}/{self.repo_name}" f"/actions/workflows/{wf_ref.file_name}?query=branch%3Atrunk%2F{commit_sha}" ) + + # Log what was dispatched with filter info + filter_info = "" + if inputs: + filter_parts = [] + if "jobs-to-include" in inputs: + filter_parts.append(f"jobs={inputs['jobs-to-include']}") + if "tests-to-include" in inputs: + 
filter_parts.append(f"tests={inputs['tests-to-include']}") + filter_info = f" with filters: {', '.join(filter_parts)}" + logging.info( - "Successfully dispatched workflow %s for commit %s (run: %s)", + "Successfully dispatched workflow %s for commit %s%s (run: %s)", wf_ref.display_name, commit_sha, + filter_info, workflow_url, ) diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_resolver.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_resolver.py index f344ea9a75..9f8dce3e4e 100644 --- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_resolver.py +++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_resolver.py @@ -7,6 +7,7 @@ from __future__ import annotations +import logging import os import re from dataclasses import dataclass @@ -14,6 +15,7 @@ from typing import Optional import github +import yaml from .github_client_helper import GHClientFactory from .utils import RetryWithBackoff @@ -27,6 +29,19 @@ class WorkflowRef: file_name: str # basename, e.g., "pull.yml" +@dataclass(frozen=True) +class WorkflowInputSupport: + """Describes which workflow_dispatch inputs a workflow accepts.""" + + jobs_to_include: bool = False + tests_to_include: bool = False + + @property + def supports_filtering(self) -> bool: + """True if workflow supports any filtering inputs.""" + return self.jobs_to_include or self.tests_to_include + + class WorkflowResolver: """Caches workflows for a repo and resolves by exact names. 
@@ -52,6 +67,7 @@ def __init__( self._repository = repository self._by_display: dict[str, WorkflowRef] = {} self._by_file: dict[str, WorkflowRef] = {} + self._input_support_cache: dict[str, WorkflowInputSupport] = {} self._build_indices() @staticmethod @@ -102,3 +118,88 @@ def _build_indices(self) -> None: ref = WorkflowRef(display_name=name, file_name=base) self._by_display[name] = ref self._by_file[base] = ref + + def get_input_support(self, workflow_name: str) -> WorkflowInputSupport: + """Check if workflow supports filtering inputs by parsing its YAML. + + Args: + workflow_name: Display name or file name of the workflow + + Returns: + WorkflowInputSupport describing which inputs are accepted + """ + ref = self.require(workflow_name) + + if ref.file_name in self._input_support_cache: + return self._input_support_cache[ref.file_name] + + support = self._fetch_and_parse_workflow_inputs(ref.file_name) + self._input_support_cache[ref.file_name] = support + return support + + def _fetch_and_parse_workflow_inputs(self, file_name: str) -> WorkflowInputSupport: + """Fetch workflow YAML from GitHub and parse for dispatch inputs. + + Args: + file_name: Workflow file basename (e.g., "trunk.yml") + + Returns: + WorkflowInputSupport with detected input support + """ + path = f".github/workflows/{file_name}" + + for attempt in RetryWithBackoff(): + with attempt: + contents = self._repository.get_contents(path) + yaml_content = contents.decoded_content.decode("utf-8") + + return self._parse_workflow_inputs(yaml_content) + + def _parse_workflow_inputs(self, yaml_content: str) -> WorkflowInputSupport: + """Parse workflow YAML content to detect supported dispatch inputs. 
+ + Args: + yaml_content: Raw YAML content of the workflow file + + Returns: + WorkflowInputSupport with detected input support + """ + try: + workflow = yaml.safe_load(yaml_content) + if not isinstance(workflow, dict): + return WorkflowInputSupport() + + # YAML 1.1 parses "on" as boolean True, so check for both + on_section = workflow.get("on") or workflow.get(True) or {} + if isinstance(on_section, str): + # Simple trigger like "on: push" + return WorkflowInputSupport() + + workflow_dispatch = on_section.get("workflow_dispatch", {}) + if not isinstance(workflow_dispatch, dict): + return WorkflowInputSupport() + + inputs = workflow_dispatch.get("inputs", {}) + if not isinstance(inputs, dict): + return WorkflowInputSupport() + + support = WorkflowInputSupport( + jobs_to_include="jobs-to-include" in inputs, + tests_to_include="tests-to-include" in inputs, + ) + + logging.debug( + "Workflow %s input support: jobs=%s, tests=%s", + self._repo_full_name, + support.jobs_to_include, + support.tests_to_include, + ) + + return support + + except yaml.YAMLError: + logging.warning("Failed to parse workflow YAML", exc_info=True) + return WorkflowInputSupport() + except Exception: + logging.warning("Unexpected error parsing workflow inputs", exc_info=True) + return WorkflowInputSupport() diff --git a/aws/lambda/pytorch-auto-revert/requirements.txt b/aws/lambda/pytorch-auto-revert/requirements.txt index 7f9b0d391a..13a817d7cb 100644 --- a/aws/lambda/pytorch-auto-revert/requirements.txt +++ b/aws/lambda/pytorch-auto-revert/requirements.txt @@ -1,5 +1,6 @@ boto3==1.38.29 clickhouse-connect==0.8.14 PyGithub==2.6.1 +PyYAML>=6.0 python-dotenv>=1.0.0 requests>=2.31.0