Skip to content
4 changes: 2 additions & 2 deletions frontend/catalyst/jax_extras/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def uses_transform(qnode, transform_name):
bool: True if `transform_name` is detected (and is only one if only_one=True),
False otherwise
"""
transform_program = getattr(qnode, "transform_program", [])
transform_funcs = [transform_container.transform for transform_container in transform_program]
compile_pipeline = getattr(qnode, "compile_pipeline", [])
transform_funcs = [bound_transform.transform for bound_transform in compile_pipeline]

return any(transform_name in func.__name__ for func in transform_funcs)
2 changes: 1 addition & 1 deletion frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,7 @@ def is_leaf(obj):
else:
device_program = qml.CompilePipeline()

qnode_program = qnode.transform_program if qnode else qml.CompilePipeline()
qnode_program = qnode.compile_pipeline if qnode else qml.CompilePipeline()

tapes, post_processing, tracing_mode = apply_transforms(
qnode_program,
Expand Down
4 changes: 2 additions & 2 deletions frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ def __call__(self, *args, **kwargs):

assert isinstance(self, qml.QNode)

new_transform_program, new_pipeline = _extract_passes(self.transform_program)
new_compile_pipeline, new_pipeline = _extract_passes(self.compile_pipeline)
# Update the qnode with peephole pipeline
old_pipeline = kwargs.pop("pass_pipeline", None)
processed_old_pipeline = tuple(dictionary_to_list_of_passes(old_pipeline))
pass_pipeline = processed_old_pipeline + new_pipeline
new_qnode = copy(self)
# pylint: disable=attribute-defined-outside-init, protected-access
new_qnode._transform_program = new_transform_program
new_qnode._compile_pipeline = new_compile_pipeline

# Mid-circuit measurement configuration:
one_shot_results = configure_mcm_and_try_one_shot(new_qnode, args, kwargs, pass_pipeline)
Expand Down
Loading