Skip to content

Commit c65e4e7

Browse files
yushangdimeta-codesync[bot]
authored andcommitted
Add backend_options to nested_compile_region to be used by regional_inductor (#167599)
Summary: - Add `aot_config` arg to `nested_compile_region`. The config looks like below. ``` dataclass class NestedCompileRegionOptions: # A Callable that takes (gm, example_inputs, decompositions=None, **kwargs) as inputs. fw_compiler: Optional[Callable] = None bw_compiler: Optional[Callable] = None # Note: [InvokeSubgraphHOP Partitioner] # If not None, add "partitioner" to HOP node meta. # If Callable, directly assign the callable, but the callable cannot be pickled # If str, the options are "default_partition" and "min_cut_rematerialization_partition". # The HOP joint graph will be partitioned using the corresponding functions in # torch/_functorch/partitioners.py partitioner: Optional[Callable | str] = None decompositions: Optional[dict[str, Any]] = None ``` For nested region that has this aot_config option, it will have `node.meta["custom"]["nested_region_config"]=this_option` on the `torch.ops.higher_order.invoke_subgraph` node. The subgraph used in torch.ops.higher_order.invoke_subgraph will be compied with fw/bw_compiler specified in the config. If the compiler is None, it will not be compiled. We provide the following convenient util function to get a config that is used for compiling the nested region using inductor: ``` get_invoke_subgraph_compile_options( inductor_config_patches=None, decompositions=None, partitioner="min_cut_rematerialization_partition", ) ``` Example using this (flex_attention will be compiled with default inductor configs): ``` nested_config = get_invoke_subgraph_compile_options(decompositions=decomp_table) torch.compiler.nested_compile_region(aot_config=nested_config) def f_flex_attention(x, y, z, block_mask, score_mod): x = flex_attention(x, y, z, block_mask=block_mask, score_mod=score_mod) return x def fn(x): x = torch.sin(x) x = f_flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared) return torch.cos(x) ``` How the regional inductor implementation works: We recursively look for `torch.ops.higher_order.invoke_subgraph` nodes that has the `node.meta["custom"]["nested_region_config"]` config. e.g. if the node `invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', arg0, arg1);` has the node meta, we take the submodule corresponding to `partitioned_bw_subgraph_0_0` getattr node, and compile it. The example inputs are obtained from the input node `arg0` and `arg1`'s node.meta["val"]. Then, we replace `invoke_subgraph_3` with a call to the new compiled_submod. We add a `compile_fn` option to `standalone_compile` so `standalone_compile` uses the custom compile fn to compile. https://docs.google.com/document/d/1vkvyP_BH_-tbvjLugmCl4kULsZU1uf1Fag2dAEBhebE/edit?tab=t.0 Nested: You can nest `nested_compile_region`, but you CANNOT nest `nested_compile_region(aot_config=nested_config)`. Only at most 1 `nested_compile_region` in the nested stack should have `aot_config`. You can have multiple `nested_compile_region(aot_config=nested_config)` in parallel, just don't nest them. ``` # valid example torch.compiler.nested_compile_region(aot_config=nested_config) def g(y): return y * 2 torch.compiler.nested_compile_region def gn(x): x = g(x) return torch.sin(x) def fn(x): x = x + 1 x = gn(x) x = x + 1 x = gn(x) return torch.sigmoid(x) ``` Compile time: Let's say I want to compile flex attention for 10 times like: ``` layer = 10 torch.compiler.nested_compile_region(aot_config=nested_config) def f_flex_attention(x, y, z, block_mask, score_mod): x = flex_attention(x, y, z, block_mask=block_mask, score_mod=score_mod) return x def fn2(x): x = torch.sin(x) x = f_flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared) return torch.cos(x) def fn(x): for i in range(layer): x = fn2(x) return x ``` Using annotation: 14,552,485 us <img width="992" height="329" alt="Screenshot 2025-12-01 at 7 58 33 PM" src="https://github.com/user-attachments/assets/903d9e1c-7e9b-42ab-9529-19bb3f32371d" /> Using invoke_subgraph: 6, 466,116 us <img width="988" height="313" alt="Screenshot 2025-12-01 at 8 00 43 PM" src="https://github.com/user-attachments/assets/be09be3e-37ab-4ae5-af65-45bb9dbb3206" /> We can see that using invoke_subgraph is faster. (benchmark code: P2063033678) Caveats/future work: - currently serialization doesn't work yet - the returned values in the compiled region will be detached twice (once in subgraph, once outside), will this hurt performance? - we should test if this is composable with cudagraph using torchtitan X-link: pytorch/pytorch#167599 Approved by: https://github.com/anijain2305, https://github.com/zou3519 Reviewed By: izaitsevfb Differential Revision: D89125275 fbshipit-source-id: a078d7f8bf3feb8fdec4c99232a59294ed0d417a
1 parent 451f4eb commit c65e4e7

File tree

1 file changed

+22
-0
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+22
-0
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/testing.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,3 +581,25 @@ def _skipped_function_for_test_reconstruct(
581581
f: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
582582
) -> _T:
583583
return f(*args, **kwargs)
584+
585+
586+
_testing_invoke_subgraph_inductor_compile_captured_gms = None
587+
588+
589+
@contextlib.contextmanager
590+
def _testing_capture_invoke_subgraph_inductor_compile_gms():
591+
"""
592+
Context manager to capture graph modules compiled by invoke_subgraph_inductor_compile.
593+
594+
Usage:
595+
with _testing_capture_invoke_subgraph_inductor_compile_gms() as captured_gms:
596+
# code that triggers invoke_subgraph_inductor_compile
597+
pass
598+
# captured_gms will contain the list of captured graph modules
599+
"""
600+
global _testing_invoke_subgraph_inductor_compile_captured_gms
601+
_testing_invoke_subgraph_inductor_compile_captured_gms = []
602+
try:
603+
yield _testing_invoke_subgraph_inductor_compile_captured_gms
604+
finally:
605+
_testing_invoke_subgraph_inductor_compile_captured_gms = None

0 commit comments

Comments
 (0)