Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
64dd86a
Implement no-cloning validation
zhenrongliew Nov 5, 2025
1b773d4
improve error reporting and update test cases for validation errors
zhenrongliew Nov 6, 2025
dca2422
Updated ValidationError reporting
zhenrongliew Nov 6, 2025
40f5d6c
Refactor no-cloning validation: enhance error handling and improve te…
zhenrongliew Nov 6, 2025
2a98e55
Shorter error messages
zhenrongliew Nov 6, 2025
5f4500d
clarify join/meet and lattice structure
zhenrongliew Nov 6, 2025
ef5a4e1
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew Nov 7, 2025
80c9349
fix linting
zhenrongliew Nov 7, 2025
6429148
Fix import warning
zhenrongliew Nov 7, 2025
2e155d9
fix import errors
zhenrongliew Nov 7, 2025
2a011fc
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew Nov 7, 2025
a35636a
Improve Validation framework to compose multiple validation analyses.
zhenrongliew Nov 7, 2025
2bf369e
removed redundant `method` variable
zhenrongliew Nov 7, 2025
a2af2f3
Fix commutativity of `join` operation
zhenrongliew Nov 10, 2025
c3680c4
Merge branch 'main' into dl/validate-no-cloning
zhenrongliew Nov 10, 2025
b572471
updated to work with new Kirin version
zhenrongliew Nov 10, 2025
9ccc42c
moved collecting errors to Kirin's InterpreterABC
zhenrongliew Nov 12, 2025
ca8f5cd
fix unused import
zhenrongliew Nov 12, 2025
aeaea1c
Moved ValidationPass to Kirin
zhenrongliew Nov 12, 2025
06c7bbb
Remove redundant code in ifelse handling.
zhenrongliew Nov 14, 2025
924d60a
Refactor validation analysis and error handling in NoCloningValidatio…
zhenrongliew Nov 17, 2025
c2c825e
Fix import
zhenrongliew Nov 17, 2025
3657581
use `raise_if_invalid` instead of `format_errors`.
zhenrongliew Nov 19, 2025
20654c5
Merge branch 'main' into dl/validate-no-cloning
david-pl Jan 26, 2026
8deed42
Fix error check in test
david-pl Jan 26, 2026
bfbe1f0
Merge branch 'main' into dl/validate-no-cloning
david-pl Jan 26, 2026
21308ab
Make func.Invoke analysis an actual method impl
david-pl Jan 26, 2026
bdebdc6
Store addresses in frame after invoke
david-pl Jan 28, 2026
f1c289d
Verify whether there is an error on func.Invoke
david-pl Jan 28, 2026
7e3973d
Revert "Store addresses in frame after invoke"
david-pl Jan 28, 2026
fb113e8
Add a failing test for debugging
david-pl Jan 29, 2026
94d0da4
AddressAnalysis: use custom frame to collect addresses in func.Invoke…
david-pl Jan 29, 2026
1e84170
Also collect invoke addresses in ifelse branches
david-pl Jan 29, 2026
4607a2c
Fix no-cloning for multiple stdlib calls with different args
david-pl Jan 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bloqade/analysis/address/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
UnknownQubit as UnknownQubit,
PartialLambda as PartialLambda,
)
from .analysis import AddressAnalysis as AddressAnalysis
from .analysis import AddressFrame as AddressFrame, AddressAnalysis as AddressAnalysis
78 changes: 75 additions & 3 deletions src/bloqade/analysis/address/analysis.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,83 @@
from typing import Any, Type, TypeVar
from dataclasses import field
from contextlib import contextmanager
from dataclasses import field, dataclass

from kirin import ir, types, interp
from kirin.analysis import Forward, const
from kirin.analysis import ForwardExtra, const
from kirin.dialects import func
from kirin.dialects.ilist import IList
from kirin.analysis.forward import ForwardFrame
from kirin.analysis.const.lattice import PartialLambda

from .lattice import Address, AddressReg, ConstResult, PartialIList, PartialTuple

InvokeKey = tuple[ir.Statement, tuple[int, ...]]


@dataclass
class AddressFrame(ForwardFrame[Address]):
_current_invoke_key: InvokeKey | None = None
_invoke_addresses: dict[InvokeKey, dict[ir.SSAValue, Address]] = field(
init=False, default_factory=dict
)

def collect_invoke_addresses(
self, call_frame: "AddressFrame", node: func.Invoke | None = None
):
if node is not None:
inputs = self.get_values(node.inputs)
input_ids = tuple(map(id, inputs))
key = (node, input_ids)

# collect the addresses found in the function body
data = self._invoke_addresses.get(key, dict())
data.update(call_frame.entries)
self._invoke_addresses[key] = data

# collect nested invokes
self._invoke_addresses.update(call_frame._invoke_addresses)

@contextmanager
def invoke_addresses(self, node: func.Invoke):
inputs = [self.get_or_fallback_to_invoke(input_) for input_ in node.inputs]
input_ids = tuple(map(id, inputs))
context_key = (node, input_ids)

reset_invoke_key = self._current_invoke_key
self._current_invoke_key = context_key
try:
yield self
finally:
self._current_invoke_key = reset_invoke_key

def get_or_fallback_to_invoke(self, key: ir.SSAValue):
"""Modified frame.get method that also checks addresses collected from
function invokes.
"""
value = self.entries.get(key, interp.Undefined)

if not interp.is_undefined(value):
return value

if self._current_invoke_key is not None:
additional_entries = self._invoke_addresses.get(
self._current_invoke_key, dict()
)
value = additional_entries.get(key, interp.Undefined)

if not interp.is_undefined(value):
return value

class AddressAnalysis(Forward[Address]):
if self.has_parent_access and self.parent:
if isinstance(self.parent, AddressFrame):
self.parent.get_or_fallback_to_invoke(key)
else:
return self.parent.get(key)

raise interp.InterpreterError(f"SSAValue {key} not found")


class AddressAnalysis(ForwardExtra[AddressFrame, Address]):
"""
This analysis pass can be used to track the global addresses of qubits and wires.
"""
Expand All @@ -36,6 +103,11 @@ def initialize(self):
self._const_prop.initialize()
return self

def initialize_frame(
self, node: ir.Statement, *, has_parent_access: bool = False
) -> AddressFrame:
return AddressFrame(node, has_parent_access=has_parent_access)

@property
def qubit_count(self) -> int:
"""Total number of qubits found by the analysis."""
Expand Down
13 changes: 9 additions & 4 deletions src/bloqade/analysis/address/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
PartialTuple,
PartialLambda,
)
from .analysis import AddressAnalysis
from .analysis import AddressFrame, AddressAnalysis


@py.constant.dialect.register(key="qubit.address")
Expand Down Expand Up @@ -177,15 +177,17 @@ def return_(
def invoke(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
frame: AddressFrame,
stmt: func.Invoke,
):
_, ret = interp_.call(
call_frame, ret = interp_.call(
stmt.callee.code,
interp_.method_self(stmt.callee),
*frame.get_values(stmt.inputs),
)

frame.collect_invoke_addresses(call_frame, stmt)

return (ret,)

@interp.impl(func.Lambda)
Expand Down Expand Up @@ -307,7 +309,7 @@ def yield_(
def ifelse(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
frame: AddressFrame,
stmt: scf.IfElse,
):
address_cond = frame.get(stmt.cond)
Expand All @@ -318,6 +320,7 @@ def ifelse(
body = stmt.then_body if const_cond.data else stmt.else_body
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
frame.collect_invoke_addresses(body_frame)
# interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values())
return ret
else:
Expand All @@ -330,6 +333,7 @@ def ifelse(
address_cond,
)
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())
frame.collect_invoke_addresses(then_frame)

with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
else_results = interp_.frame_call_region(
Expand All @@ -339,6 +343,7 @@ def ifelse(
address_cond,
)
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
frame.collect_invoke_addresses(else_frame)
# TODO: pick the non-return value
if isinstance(then_results, interp.ReturnValue) and isinstance(
else_results, interp.ReturnValue
Expand Down
2 changes: 2 additions & 0 deletions src/bloqade/analysis/validation/nocloning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import impls as impls
from .analysis import NoCloningValidation as NoCloningValidation
180 changes: 180 additions & 0 deletions src/bloqade/analysis/validation/nocloning/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from typing import Any
from dataclasses import dataclass

from kirin import ir
from kirin.analysis import Forward
from kirin.ir.exception import (
ValidationError,
DefiniteValidationError,
PotentialValidationError,
)
from kirin.analysis.forward import ForwardFrame
from kirin.validation.validationpass import ValidationPass

from bloqade.analysis.address import AddressFrame, AddressAnalysis

from .lattice import May, Must, Bottom, QubitValidation


class QubitValidationError(DefiniteValidationError):
"""ValidationError for definite (Must) violations with concrete qubit addresses."""

qubit_id: int
gate_name: str

def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str):
super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.")
self.qubit_id = qubit_id
self.gate_name = gate_name


class PotentialQubitValidationError(PotentialValidationError):
"""ValidationError for potential (May) violations with unknown addresses."""

gate_name: str
condition: str

def __init__(self, node: ir.IRNode, gate_name: str, condition: str):
super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.")
self.gate_name = gate_name
self.condition = condition


@dataclass
class _NoCloningAnalysis(Forward[QubitValidation]):
"""Internal forward analysis for tracking qubit cloning violations."""

keys = ("validate.nocloning",)
lattice = QubitValidation
_address_frame: AddressFrame | None = None

def method_self(self, method: ir.Method) -> QubitValidation:
return self.lattice.bottom()

def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation):
if self._address_frame is None:
addr_analysis = AddressAnalysis(self.dialects)
addr_analysis.initialize()
self._address_frame, _ = addr_analysis.run(method)
return super().run(method, *args, **kwargs)

def eval_fallback(
self, frame: ForwardFrame[QubitValidation], node: ir.Statement
) -> tuple[QubitValidation, ...]:
"""Check for qubit usage violations and return lattice values."""
return tuple(Bottom() for _ in node.results)

def _get_source_name(self, value: ir.SSAValue) -> str:
"""Trace back to get the source variable name."""
from kirin.dialects.py.indexing import GetItem

if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem):
index_arg = value.stmt.args[1]
return self._get_source_name(index_arg)

if isinstance(value, ir.BlockArgument):
return value.name or f"arg{value.index}"

if hasattr(value, "name") and value.name:
return value.name

return str(value)

def extract_errors_from_frame(
self, frame: ForwardFrame[QubitValidation]
) -> list[ValidationError]:
"""Extract validation errors from final lattice values.

Only extracts errors from top-level statements (not nested in regions).
"""
errors = []
seen_statements = set()

for node, value in frame.entries.items():
if isinstance(node, ir.ResultValue):
stmt = node.stmt
elif isinstance(node, ir.Statement):
stmt = node
else:
continue
if stmt in seen_statements:
continue
seen_statements.add(stmt)
if isinstance(value, Must):
for qubit_id, gate_name in value.violations:
errors.append(QubitValidationError(stmt, qubit_id, gate_name))
elif isinstance(value, May):
for gate_name, condition in value.violations:
errors.append(
PotentialQubitValidationError(stmt, gate_name, condition)
)
return errors

def count_violations(self, frame: Any) -> int:
"""Count individual violations from the frame, same as test helper."""
from .lattice import May, Must

total = 0
for node, value in frame.entries.items():
if isinstance(value, Must):
total += len(value.violations)
elif isinstance(value, May):
total += len(value.violations)
return total


class NoCloningValidation(ValidationPass):
"""Validates the no-cloning theorem by tracking qubit addresses."""

def __init__(self):
self._analysis: _NoCloningAnalysis | None = None
self._cached_address_frame = None

def name(self) -> str:
return "No-Cloning Validation"

def get_required_analyses(self) -> list[type]:
"""Declare dependency on AddressAnalysis."""
return [AddressAnalysis]

def set_analysis_cache(self, cache: dict[type, Any]) -> None:
"""Use cached AddressAnalysis result."""
self._cached_address_frame = cache.get(AddressAnalysis)

def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
"""Run the no-cloning validation analysis."""
if self._analysis is None:
self._analysis = _NoCloningAnalysis(method.dialects)

self._analysis.initialize()
if self._cached_address_frame is not None:
self._analysis._address_frame = self._cached_address_frame

frame, _ = self._analysis.run(method)
errors = self._analysis.extract_errors_from_frame(frame)

return frame, errors

def print_validation_errors(self):
"""Print all collected errors with formatted snippets."""
if self._analysis is None:
return

if self._analysis.state._current_frame:
frame = self._analysis.state._current_frame
errors = self._analysis.extract_errors_from_frame(frame)

for err in errors:
if isinstance(err, QubitValidationError):
print(
f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate"
)
elif isinstance(err, PotentialQubitValidationError):
print(
f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}"
)
else:
print(
f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}"
)
print(err.hint())
Loading
Loading