Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 145 additions & 73 deletions frontend/catalyst/python_interface/pass_api/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
"""This file contains PennyLane's API for defining compiler passes."""

from collections.abc import Callable, Sequence
from functools import partial
from inspect import signature
from typing import ClassVar, Union, UnionType, get_args, get_origin

from pennylane.transforms.core import TransformDispatcher
from xdsl.context import Context
from xdsl.dialects import builtin
from xdsl.ir import Operation
Expand All @@ -29,83 +31,64 @@
op_type_rewrite_pattern,
)

from .apply_transform_sequence import register_pass

def _update_type_hints(expected_types: Sequence[type[Operation]]) -> Callable:
"""Update the signature of a ``match_and_rewrite`` method to use the provided operation
as the first argument's type hint."""

if not all(issubclass(e, Operation) for e in expected_types):
raise TypeError(
"Only Operation types or unions of Operation types can be used to "
"register rewrite rules."
)
hint = expected_types[0] if len(expected_types) == 1 else Union[*expected_types]
class PassDefinitionError(Exception):
"""Exception for when a pass is ill-defined."""

def _update_match_and_rewrite(method: Callable) -> Callable:
params = tuple(signature(method).parameters)
# Update type hint of operation argument
# TODO: Is it fine to mutate in-place or should we return a new function?
op_arg_name = params[-2]
method.__annotations__[op_arg_name] = hint

return method

return _update_match_and_rewrite


def _create_rewrite_pattern(
expected_types: Sequence[type[Operation]], rewrite_rule: Callable
) -> RewritePattern:
"""Given a rewrite rule defined as a function, create a ``RewritePattern`` which
can be used with xDSL's pass API."""

# pylint: disable=too-few-public-methods, arguments-differ
class _RewritePattern(RewritePattern):
"""Anonymous rewrite pattern for transforming a matched operation."""
class PassMeta(type):
"""Metaclass for automatically registering xDSL passes with the PennyLane transform API."""

_pass: PLModulePass
def __init__(cls, /):
pass_name = getattr(cls, "name", None)
if not pass_name:
raise PassDefinitionError(
f"The 'name' field must be specified when defining a new PLModulePass."
)

def __init__(self, _pass):
self._pass = _pass
super().__init__()
dispatcher = TransformDispatcher(pass_name=pass_name)
dispatcher.__doc__ = cls.__doc__

@op_type_rewrite_pattern
@_update_type_hints(expected_types)
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
rewrite_rule(self._pass, op, rewriter)
def get_pass_cls():
return cls

return _RewritePattern
register_pass(pass_name, get_pass_cls)
cls._dispatcher = dispatcher


class PLModulePass(ModulePass):
class PLModulePass(ModulePass, metaclass=PassMeta):
"""An xdsl ``ModulePass`` subclass for defining passes."""

name: ClassVar[str]
_rewrite_patterns: ClassVar[dict[Operation, RewritePattern]] = {}
"""The name of the pass."""

def __init__(self, recursive: bool = True, greedy: bool = False):
self.recursive = recursive
self.greedy = greedy
recursive: ClassVar[bool] = True
"""Whether or not the rewrite rules should be applied recursively. If ``True``,
the rewrite rules will be applied repeatedly until a steady-state is reached.
``True`` by default.
"""

@property
def recursive(self):
"""Whether or not the rewrite rules should be applied recursively. If ``True``,
the rewrite rules will be applied repeatedly until a steady-state is reached.
``True`` by default.
"""
return True
greedy: ClassVar[bool] = True
"""Whether or not the rewrite rules should be applied greedily. If ``True``,
each iteration of the rewrite rules' application (if ``recursive == True``)
will only apply the first rewrite rule that modifies the input module.
``True`` by default."""

@property
def greedy(self):
"""Whether or not the rewrite rules should be applied greedily. If ``True``,
each iteration of the rewrite rules' application (if ``recursive == True``)
will only apply the first rewrite rule that modifies the input module.
``True`` by default."""
return True
_rewrite_patterns: ClassVar[dict[Operation, RewritePattern]] = {}
"""Registered rewrite patterns."""

_dispatcher: ClassVar[TransformDispatcher]
"""The ``TransformDispatcher`` instance corresponding to this pass. Subclasses of
``PLModulePass`` are automatically registered with the PennyLane transform API."""

@classmethod
def rewrite_rule(
cls, rule: Callable[["PLModulePass", Operation, PatternRewriter], None]
cls,
hint_or_rule: (
type[Operation] | Callable[["PLModulePass", Operation, PatternRewriter], None]
),
) -> None:
"""Register a rewrite rule.

Expand Down Expand Up @@ -145,22 +128,46 @@ def rewrite_myop(
Returns:
Callable: a decorator to register the rewrite rule with the ModulePass
"""
# xdsl.pattern_rewriter.op_type_rewrite_pattern was used as a reference to
# implement the type hint collection. Source:
# https://github.com/xdslproject/xdsl/blob/main/xdsl/pattern_rewriter.py
params = [param for param in signature(rule, eval_str=True).parameters.values()]
if len(params) != 3 or params[0].name != "self":
raise ValueError(
"The rewrite rule must have 3 arguments, with the first one being 'self'."
)

# If a type hint for the op we're trying to match isn't provided, match all ops
hint = Operation if params[-2] not in rule.__annotations__ else params[-2].annotation
expected_types = [hint] if get_origin(hint) in (Union, UnionType) else get_args(hint)
rewrite_pattern = _create_rewrite_pattern(expected_types, rule)

for et in expected_types:
cls._rewrite_patterns[et] = rewrite_pattern
def rewrite_rule_wrapper(
rule: Callable[["PLModulePass", Operation, PatternRewriter], None],
expected_types: tuple[type[Operation], ...] = None,
):
"""A wrapper to register the input rewrite rule with the transform."""

if not expected_types:
# xdsl.pattern_rewriter.op_type_rewrite_pattern was used as a reference to
# implement the type hint collection. Source:
# https://github.com/xdslproject/xdsl/blob/main/xdsl/pattern_rewriter.py
params = [param for param in signature(rule, eval_str=True).parameters.values()]
if len(params) != 3 or params[0].name != "self":
raise PassDefinitionError(
"The rewrite rule must have 3 arguments, with the first one being 'self'."
)

# If a type hint for the op we're trying to match isn't provided, match all ops
hint = (
Operation if params[-2] not in rule.__annotations__ else params[-2].annotation
)
expected_types = _get_expected_types(hint)
if expected_types is None:
raise PassDefinitionError(
f"The provided rewrite rule, {rule}, uses an invalid type hint, {hint} for "
"the operation to be rewritten."
)

rewrite_pattern = _create_rewrite_pattern(expected_types, rule)
for et in expected_types:
cls._rewrite_patterns[et] = rewrite_pattern

# If the argument is a type hint, then we want to return a decorator
if (expected_types := _get_expected_types(hint_or_rule)) is not None:
return partial(rewrite_rule_wrapper, expected_types=expected_types)

# Else, we just register the rewrite rule, using the rule's type hints
# to get the expected types
rule = hint_or_rule
return rewrite_rule_wrapper(rule)

@property
def rewrite_rules(self):
Expand Down Expand Up @@ -197,3 +204,68 @@ def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: # pylint: disable=
for rp in self._rewrite_patterns.values():
walker = PatternRewriteWalker(pattern=rp(self), apply_recursively=self.recursive)
walker.rewrite_module(op)

def __call__(self, *args, **kwargs):
callee = getattr(self, "_dispatcher", super())
return callee(*args, **kwargs)


def _get_expected_types(maybe_hint) -> tuple[type[Operation], ...] | None:
"""Get the expected types from the input if the input is a type hint. If not,
``None`` will be returned."""
if isinstance(maybe_hint, type) and issubclass(maybe_hint, Operation):
return (maybe_hint,)

if (origin := get_origin(maybe_hint)) is not None:
expected_types = get_args(maybe_hint)
if not (
origin in (Union, UnionType) and all(issubclass(e, Operation) for e in expected_types)
):
raise PassDefinitionError(
"Only Operation types or unions of Operation types can be used to "
f"register rewrite rules. Got {maybe_hint}."
)
return expected_types

return None


def _update_type_hints(expected_types: tuple[type[Operation], ...]) -> Callable:
"""Update the signature of a ``match_and_rewrite`` method to use the provided operation
as the first argument's type hint."""
hint = expected_types[0] if len(expected_types) == 1 else Union[*expected_types]

def _update_match_and_rewrite(method: Callable) -> Callable:
params = tuple(signature(method).parameters)
# Update type hint of operation argument
# TODO: Is it fine to mutate in-place or should we return a new function?
op_arg_name = params[-2]
method.__annotations__[op_arg_name] = hint

return method

return _update_match_and_rewrite


def _create_rewrite_pattern(
expected_types: Sequence[type[Operation]], rewrite_rule: Callable
) -> RewritePattern:
"""Given a rewrite rule defined as a function, create a ``RewritePattern`` which
can be used with xDSL's pass API."""

# pylint: disable=too-few-public-methods, arguments-differ
class _RewritePattern(RewritePattern):
"""Anonymous rewrite pattern for transforming a matched operation."""

_pass: PLModulePass

def __init__(self, _pass):
self._pass = _pass
super().__init__()

@op_type_rewrite_pattern
@_update_type_hints(expected_types)
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
rewrite_rule(self._pass, op, rewriter)

return _RewritePattern
Loading