Commit c65e4e7
Add backend_options to nested_compile_region to be used by regional_inductor (#167599)
Summary:
- Add an `aot_config` arg to `nested_compile_region`. The config dataclass looks like this:
```
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class NestedCompileRegionOptions:
    # A Callable that takes (gm, example_inputs, decompositions=None, **kwargs) as inputs.
    fw_compiler: Optional[Callable] = None
    bw_compiler: Optional[Callable] = None
    # Note: [InvokeSubgraphHOP Partitioner]
    # If not None, add "partitioner" to the HOP node meta.
    # If a Callable, assign it directly; note that the callable cannot be pickled.
    # If a str, the options are "default_partition" and "min_cut_rematerialization_partition";
    # the HOP joint graph will be partitioned using the corresponding functions in
    # torch/_functorch/partitioners.py.
    partitioner: Optional[Callable | str] = None
    decompositions: Optional[dict[str, Any]] = None
```
For a nested region that has this `aot_config` option, the `torch.ops.higher_order.invoke_subgraph` node will carry `node.meta["custom"]["nested_region_config"] = this_option`.
The subgraph used in `torch.ops.higher_order.invoke_subgraph` will be compiled with the fw/bw_compiler specified in the config. If a compiler is None, the corresponding subgraph will not be compiled.
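For illustration, a config could also be constructed directly. This is only a sketch; the import location of `NestedCompileRegionOptions` and the toy `my_fw_compiler` below are assumptions, not part of this PR:
```
# Sketch only: assumes NestedCompileRegionOptions is importable from wherever
# this PR defines it; my_fw_compiler is a toy compiler for illustration.
def my_fw_compiler(gm, example_inputs, decompositions=None, **kwargs):
    # Inspect (or transform) the forward subgraph, then return a callable.
    gm.graph.print_tabular()
    return gm

nested_config = NestedCompileRegionOptions(
    fw_compiler=my_fw_compiler,
    bw_compiler=None,  # None => the backward subgraph is left uncompiled
    partitioner="min_cut_rematerialization_partition",  # or "default_partition"
    decompositions=None,
)
```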
We provide the following convenience util function to get a config for compiling the nested region with inductor:
```
get_invoke_subgraph_compile_options(
    inductor_config_patches=None,
    decompositions=None,
    partitioner="min_cut_rematerialization_partition",
)
```
Example using this (flex_attention will be compiled with default inductor configs):
```
nested_config = get_invoke_subgraph_compile_options(decompositions=decomp_table)

@torch.compiler.nested_compile_region(aot_config=nested_config)
def f_flex_attention(x, y, z, block_mask, score_mod):
    x = flex_attention(x, y, z, block_mask=block_mask, score_mod=score_mod)
    return x

def fn(x):
    x = torch.sin(x)
    x = f_flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
    return torch.cos(x)
```
How the regional inductor implementation works:
We recursively look for `torch.ops.higher_order.invoke_subgraph` nodes that have the `node.meta["custom"]["nested_region_config"]` config.
E.g. if the node `invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', arg0, arg1)` has that node meta, we take the submodule corresponding to the `partitioned_bw_subgraph_0_0` getattr node and compile it.
The example inputs are obtained from the `node.meta["val"]` of the input nodes `arg0` and `arg1`.
Then, we replace `invoke_subgraph_3` with a call to the newly compiled submodule.
We also add a `compile_fn` option to `standalone_compile` so that `standalone_compile` uses the custom compile fn to compile.
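The following is a minimal sketch of the graph pass described above, not the actual implementation in this PR; `compile_backend` and `_CompiledWrapper` are assumptions standing in for the inductor-backed compile path:
```
import torch
import torch.fx as fx
from torch import nn

class _CompiledWrapper(nn.Module):
    # Wraps an opaque compiled callable so it can be registered as an FX submodule.
    def __init__(self, fn):
        super().__init__()
        self._fn = fn

    def forward(self, *args):
        return self._fn(*args)

def compile_marked_regions(gm: fx.GraphModule, compile_backend):
    # compile_backend(sub_gm, example_inputs) -> callable is assumed to be a thin
    # wrapper around inductor compilation (e.g. standalone_compile with compile_fn).
    # Recurse into child GraphModules first so nested invoke_subgraph calls are seen.
    for _, child in gm.named_children():
        if isinstance(child, fx.GraphModule):
            compile_marked_regions(child, compile_backend)
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target is not torch.ops.higher_order.invoke_subgraph:
            continue
        config = node.meta.get("custom", {}).get("nested_region_config")
        if config is None:
            continue
        # args = (subgraph_getattr_node, identifier_str, *operands)
        subgraph_node, _identifier, *operands = node.args
        sub_gm = getattr(gm, subgraph_node.target)
        # Example inputs come from the FakeTensor vals on the operand nodes.
        example_inputs = [op.meta["val"] for op in operands]
        compiled = compile_backend(sub_gm, example_inputs)
        name = f"compiled_{subgraph_node.target}"
        gm.add_submodule(name, _CompiledWrapper(compiled))
        # Replace the invoke_subgraph call with a call to the compiled submodule.
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_module(name, args=tuple(operands))
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)
    gm.recompile()
    return gm
```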
https://docs.google.com/document/d/1vkvyP_BH_-tbvjLugmCl4kULsZU1uf1Fag2dAEBhebE/edit?tab=t.0
Nested:
You can nest `nested_compile_region`, but you CANNOT nest `nested_compile_region(aot_config=nested_config)`: at most one `nested_compile_region` in a nested stack may carry an `aot_config`.
You can have multiple `nested_compile_region(aot_config=nested_config)` regions in parallel, just don't nest them (see the valid example below and the invalid counterpart after it).
```
# valid example
@torch.compiler.nested_compile_region(aot_config=nested_config)
def g(y):
    return y * 2

@torch.compiler.nested_compile_region
def gn(x):
    x = g(x)
    return torch.sin(x)

def fn(x):
    x = x + 1
    x = gn(x)
    x = x + 1
    x = gn(x)
    return torch.sigmoid(x)
```
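For contrast, this illustrative snippet shows the disallowed pattern, where two `aot_config`-carrying regions are nested:
```
# INVALID example (illustrative): the outer region gn nests another region g
# that also carries an aot_config, which is not allowed.
@torch.compiler.nested_compile_region(aot_config=nested_config)
def g(y):
    return y * 2

@torch.compiler.nested_compile_region(aot_config=nested_config)
def gn(x):
    x = g(x)
    return torch.sin(x)
```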
Compile time:
Let's say I want to compile flex attention 10 times, like:
```
layer = 10

@torch.compiler.nested_compile_region(aot_config=nested_config)
def f_flex_attention(x, y, z, block_mask, score_mod):
    x = flex_attention(x, y, z, block_mask=block_mask, score_mod=score_mod)
    return x

def fn2(x):
    x = torch.sin(x)
    x = f_flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
    return torch.cos(x)

def fn(x):
    for i in range(layer):
        x = fn2(x)
    return x
```
Using annotation:
14,552,485 us
<img width="992" height="329" alt="Screenshot 2025-12-01 at 7 58 33 PM" src="https://github.com/user-attachments/assets/903d9e1c-7e9b-42ab-9529-19bb3f32371d" />
Using invoke_subgraph:
6,466,116 us
<img width="988" height="313" alt="Screenshot 2025-12-01 at 8 00 43 PM" src="https://github.com/user-attachments/assets/be09be3e-37ab-4ae5-af65-45bb9dbb3206" />
We can see that using invoke_subgraph is faster. (benchmark code: P2063033678)
Caveats/future work:
- serialization doesn't work yet
- the returned values of the compiled region will be detached twice (once in the subgraph, once outside); will this hurt performance?
- we should test whether this composes with cudagraphs using torchtitan
X-link: pytorch/pytorch#167599
Approved by: https://github.com/anijain2305, https://github.com/zou3519
Reviewed By: izaitsevfb
Differential Revision: D89125275
fbshipit-source-id: a078d7f8bf3feb8fdec4c99232a59294ed0d417a