
Commit 9a1cda4

Merge branch 'main' into export-D88661769
2 parents 1bb583d + 1fe59c8 commit 9a1cda4


15 files changed: +266, -250 lines


backends/cadence/aot/TARGETS

Lines changed: 3 additions & 0 deletions
@@ -641,6 +641,9 @@ python_unittest(
     typing = True,
     deps = [
         "//caffe2:torch",
+        "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
+        "//executorch/exir:pass_base",
+        "//pytorch/ao:torchao",
     ],
 )

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 50 additions & 0 deletions
@@ -9,14 +9,64 @@
 import unittest
 
 import torch
+from executorch.backends.cadence.aot.graph_builder import GraphBuilder
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
 
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWith16BitMatmulActivationsQuantizer,
+    qconfig_A16,
     qconfig_A8W8,
 )
+from executorch.exir.pass_base import NodeMetadata
+from torchao.quantization.pt2e.quantizer.quantizer import (
+    Q_ANNOTATION_KEY,
+    QuantizationAnnotation,
+)
+
+
+class QuantizerAnnotationTest(unittest.TestCase):
+    """Unit tests for verifying quantizer annotations are correctly applied."""
+
+    def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a matmul operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 8))
+        y = builder.placeholder("y", torch.randn(8, 4))
+        matmul = builder.call_operator(
+            op=torch.ops.aten.matmul.default,
+            args=(x, y),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
+            ),
+        )
+        builder.output([matmul])
+        gm = builder.get_graph_module()
+
+        matmul_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.matmul.default,
+        )
+        self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
+        return gm, matmul_nodes[0]
+
+    def test_matmul_16bit_quantizer_annotation(self) -> None:
+        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
+        gm, matmul_node = self._build_matmul_graph()
+
+        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
+        quantizer.annotate(gm)
+
+        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(annotation._annotated)
+
+        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+
+        self.assertEqual(len(annotation.input_qspec_map), 2)
+        for _, input_qspec in annotation.input_qspec_map.items():
+            self.assertEqual(input_qspec, qconfig_A16.input_activation)
 
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
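A quick way to exercise only the new annotation tests (not part of this commit; it assumes the repository root is importable as the `executorch` package):

import unittest

# Load just the QuantizerAnnotationTest case added above and run it verbosely.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "executorch.backends.cadence.aot.tests.test_quantizer_ops.QuantizerAnnotationTest"
)
unittest.TextTestRunner(verbosity=2).run(suite)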

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 18 additions & 2 deletions
@@ -171,6 +171,22 @@ def _get_convolution_replacement(self, node) -> int:
             weight_permuted,
         )
 
+        quantized_multiplier_tensor = create_constant_placeholder(
+            self.exported_program,
+            node.graph,
+            node.name + "_quantized_multiplier",
+            InputKind.PARAMETER,
+            torch.tensor(quantized_multipliers, dtype=torch.int32),
+        )
+
+        quantized_shift_tensor = create_constant_placeholder(
+            self.exported_program,
+            node.graph,
+            node.name + "_quantized_shift",
+            InputKind.PARAMETER,
+            torch.tensor(quantized_shifts, dtype=torch.int32),
+        )
+
         new_args = (
             x,
             weight_nhwc,
@@ -180,8 +196,8 @@ def _get_convolution_replacement(self, node) -> int:
             dilation,
             -input_zero_point,
             output_zero_point,
-            torch.tensor(quantized_multipliers, dtype=torch.int32),
-            torch.tensor(quantized_shifts, dtype=torch.int32),
+            quantized_multiplier_tensor,
+            quantized_shift_tensor,
             output_qmin,
             output_qmax,
         )
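For context (not part of the diff): the requantization multiplier and shift tensors are now registered as constant placeholders, so after the pass they should surface as lifted parameters of the exported program rather than inline tensors in the op's args. A rough check, assuming the "_quantized_multiplier" / "_quantized_shift" naming above and that create_constant_placeholder records the values in the program's state dict:

from torch.export import ExportedProgram


def requant_constant_names(exported_program: ExportedProgram) -> list[str]:
    # Collect the names of the requantization constants created by the pass.
    return [
        name
        for name in exported_program.state_dict
        if name.endswith("_quantized_multiplier") or name.endswith("_quantized_shift")
    ]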

backends/qualcomm/README.md

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ Please check `generate_qnn_executorch_compiler_spec()` in
 - SXR2330P
 - QCS9100
 - SAR2230P
+- SW6100
 
 ### Adding more supported Chipset
 Currently, users cannot add additional chipset models because the chipset ID is not accessible to community users. If you have specific chipset models you wish to add, please contact one of the authors in the `Code Reviews` section at the bottom of this page.

backends/qualcomm/serialization/qc_compiler_spec.fbs

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ enum QcomChipset: int {
   QCS9100 = 77,
   SAR2230P = 95,
   SA8255 = 52,
+  SW6100 = 96,
 }
 
 /// Indicate the information of the specified SoC.

backends/qualcomm/serialization/qc_schema.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,7 @@ class QcomChipset(IntEnum):
     QCS9100 = 77  # v73
     SAR2230P = 95  # v81
     SA8255 = 52  # v73
+    SW6100 = 96  # v81
 
 
 @dataclass
@@ -80,6 +81,7 @@ class SocInfo:
     QcomChipset.SXR2330P: SocInfo(QcomChipset.SXR2330P, HtpInfo(HtpArch.V79, 8)),
     QcomChipset.QCS9100: SocInfo(QcomChipset.QCS9100, HtpInfo(HtpArch.V73, 8)),
     QcomChipset.SAR2230P: SocInfo(QcomChipset.SAR2230P, HtpInfo(HtpArch.V81, 4)),
+    QcomChipset.SW6100: SocInfo(QcomChipset.SW6100, HtpInfo(HtpArch.V81, 4)),
 }
 
 

backends/qualcomm/utils/utils.py

Lines changed: 2 additions & 0 deletions
@@ -1106,6 +1106,7 @@ def get_soc_to_arch_map():
         "SXR2330P": HtpArch.V79,
         "QCS9100": HtpArch.V73,
         "SAR2230P": HtpArch.V81,
+        "SW6100": HtpArch.V81,
     }
 
 
@@ -1127,6 +1128,7 @@ def get_soc_to_chipset_map():
         "SXR2330P": QcomChipset.SXR2330P,
         "QCS9100": QcomChipset.QCS9100,
        "SAR2230P": QcomChipset.SAR2230P,
+        "SW6100": QcomChipset.SW6100,
     }
 
 
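A short usage sketch for the new entry (not part of the diff; it assumes generate_htp_compiler_spec() and generate_qnn_executorch_compiler_spec() keep their current signatures in backends/qualcomm/utils/utils.py):

from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_arch_map,
)


def sw6100_compile_spec():
    # The new map entry resolves SW6100 to HTP arch v81.
    assert get_soc_to_arch_map()["SW6100"] is not None
    backend_options = generate_htp_compiler_spec(use_fp16=True)
    return generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SW6100,
        backend_options=backend_options,
    )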

examples/models/llama/llama_transformer.py

Lines changed: 7 additions & 0 deletions
@@ -272,6 +272,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                     norm_eps=model_args.norm_eps,
                 )
             )
+        elif (
+            model_args.layer_types
+            and model_args.layer_types[layer_id] == "skip_attention"
+        ):
+            attention = AttentionSkip()
+            transformer_block = TransformerBlock(model_args, attention)
+            layers.append(transformer_block)
         else:
             attention = cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
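For illustration only (not in the diff, and the module path and field values below are assumptions inferred from the new branch): a per-layer layer_types entry of "skip_attention" now routes that block through AttentionSkip instead of the regular attention class.

# Illustration only -- field values and the ModelArgs module path are assumptions,
# not part of this commit. construct_transformer() walks layer_types per layer_id;
# any entry equal to "skip_attention" produces a TransformerBlock built around AttentionSkip.
from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs  # assumed module path

args = ModelArgs(
    dim=64,
    n_heads=4,
    n_layers=2,
    vocab_size=128,
    layer_types=["attention", "skip_attention"],  # second layer skips attention
)
model = construct_transformer(args)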

examples/portable/executor_runner/executor_runner.cpp

Lines changed: 2 additions & 0 deletions
@@ -19,7 +19,9 @@
  */
 
 #include <cstdint>
+#ifdef ET_BUNDLE_IO_ENABLED
 #include <filesystem>
+#endif // ET_BUNDLE_IO_ENABLED
 #include <fstream>
 #include <iostream>
 #include <memory>

exir/passes/spec_prop_pass.py

Lines changed: 46 additions & 90 deletions
@@ -6,14 +6,16 @@
 
 # pyre-strict
 
-from typing import List, Optional
+import operator
+from typing import Optional
 
 import torch
 from executorch.exir.delegate import executorch_call_delegate
-from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+from executorch.exir.pass_base import ExportPass, ProxyValue
 from executorch.exir.tensor import TensorSpec
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx.node import Node
+from torch.fx.passes.infra.pass_base import PassResult
 from torch.utils import _pytree as pytree
 
 
@@ -52,12 +54,48 @@ class SpecPropPass(ExportPass):
     def __init__(self) -> None:
         super().__init__()
 
-    def on_attr(self, attr: ProxyValue) -> None:
-        attr.node.meta["spec"] = pytree.tree_map_only(
-            torch.Tensor,
-            make_spec,
-            attr.data,
-        )
+    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        # Re-trace metadata to ensure it's up to date.
+        res = ExportPass()(graph_module)
+        assert res is not None
+        gm = res.graph_module
+
+        def get_spec(x):
+            if hasattr(x, "meta"):
+                return x.meta.get("spec", None)
+            else:
+                return None
+
+        for module in gm.modules():
+            if isinstance(module, torch.fx.GraphModule):
+                for node in module.graph.nodes:
+                    meta_val = node.meta.get("val", None)
+
+                    if node.op == "output":
+                        node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
+                    elif node.op == "call_function" and node.target == operator.getitem:
+                        value_spec = pytree.tree_map(get_spec, node.args[0])
+                        node.meta["spec"] = value_spec[node.args[1]]
+                    elif (
+                        node.op == "call_function"
+                        and node.target == executorch_call_delegate
+                    ):
+                        # Note: We currently rely on delegate node specs not being regenerated,
+                        # as the spec is set somewhat manually when adding the call delegate node.
+                        # If we regenerate, it can change and break lowering (it becomes a tuple?).
+                        # Ideally, we should figure out how to make the spec regeneration not break
+                        # things.
+                        #
+                        # We do need to regenerate non-call-delegate node specs, as this pass is called
+                        # multiple times in some lowering paths (backends can and do call it).
+                        if "spec" not in node.meta:
+                            node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+                    else:
+                        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+        return res
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        return self(graph_module)
 
     def update_placeholder_tensor_specs(
         self,
@@ -84,85 +122,3 @@ def update_placeholder_tensor_specs(
                 in exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 spec.const = True
-
-    # pyre-ignore
-    def placeholder(self, name: str, arg, meta):
-        meta["spec"] = make_spec(arg)
-        return super().placeholder(name, arg, meta)
-
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
-        return super().call_operator(op, args, kwargs, meta)
-
-    # pyre-ignore
-    def call_getitem(self, value, key: int, meta):
-        meta["spec"] = value.node.meta["spec"][key]
-        return super().call_getitem(value, key, meta)
-
-    # pyre-ignore
-    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
-        # true_fn/false_fn return tensors of the same shape, so we can pick
-        # either one here.
-        *_, true_out_node = true_fn.graph.nodes
-        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
-        return super().call_cond(pred, true_fn, false_fn, inputs, meta)
-
-    def call_while(
-        self,
-        cond_fn: torch.fx.GraphModule,
-        body_fn: torch.fx.GraphModule,
-        carried_inputs: List[ProxyValue],
-        additional_inputs: List[ProxyValue],
-        meta: NodeMetadata,
-    ):
-        meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
-        return super().call_while(
-            cond_fn, body_fn, carried_inputs, additional_inputs, meta
-        )
-
-    def call_map(
-        self,
-        f: torch.fx.GraphModule,
-        mapped_args: List[ProxyValue],
-        operands: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
-        *_, body_out_node = f.graph.nodes
-        body_out_node_fake_tensor = body_out_node.meta["val"]
-        map_fake_tensor = pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: x.new_empty(mapped_dim_size, *x.shape),
-            body_out_node_fake_tensor,
-        )
-        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
-        return super().call_map(f, mapped_args, operands, meta)
-
-    # pyre-ignore
-    def call_delegate(self, lowered_module, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        # If spec is missing, re-genenrate it with args data
-        if "spec" not in meta:
-            meta["spec"] = pytree.tree_map(
-                make_spec,
-                executorch_call_delegate(lowered_module, *args_data),
-            )
-        return super().call_delegate(lowered_module, args, kwargs, meta)
-
-    # pyre-ignore
-    def output(self, results, meta):
-        # pyre-ignore
-        def get_spec(x):
-            if isinstance(x, ProxyValue):
-                return x.node.meta["spec"]
-            else:
-                return make_spec(x)
-
-        meta["spec"] = pytree.tree_map(get_spec, results)
-        return super().output(results, meta)
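A small usage sketch (not from the commit) of how the reworked pass is now driven: it is invoked like an ordinary graph pass, returns a PassResult, and leaves a TensorSpec under each node's meta["spec"]:

import torch

from executorch.exir.passes.spec_prop_pass import SpecPropPass


def propagate_specs(gm: torch.fx.GraphModule) -> dict:
    # Run the pass; it re-traces metadata via ExportPass and then fills node.meta["spec"].
    # Assumes gm comes from an export-style trace so nodes carry "val" metadata.
    res = SpecPropPass()(gm)
    assert res is not None
    return {node.name: node.meta.get("spec") for node in res.graph_module.graph.nodes}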

0 commit comments