Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
300 changes: 300 additions & 0 deletions frame/tools/target_language.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
import typing
from abc import ABC, abstractmethod


class TargetLanguage(ABC):
"""Abstract base class for target language operator translations."""

@abstractmethod
def translate_arith_op(self, op: str) -> str:
"""Translate arithmetic operators like +, -, *, /, %, ^"""
pass

@abstractmethod
def translate_logical_op(self, op: str) -> str:
"""Translate comparison operators like ==, >, <, >=, <=, !="""
pass

@abstractmethod
def translate_boolean_op(self, op: str) -> str:
"""Translate boolean operators like And, Or, Implies, Not"""
pass

@abstractmethod
def translate_quantifier(self, quantifier: str) -> str:
"""Translate quantifiers like ForAll, Exists"""
pass

@abstractmethod
def translate_boolean_literal(self, literal: str) -> str:
"""Translate boolean literals like True, False"""
pass

@abstractmethod
def format_function_call(self, function_name: str, args: typing.List[str]) -> str:
"""Format a function call with given name and arguments"""
pass

@abstractmethod
def format_quantified_expression(self, quantifier: str, bound_vars: typing.List[str], body: str) -> str:
"""Format a quantified expression with bound variables and body"""
pass

@abstractmethod
def get_preamble(self) -> str:
"""Get preamble code to add at the beginning of the program"""
pass


class SMTLibTargetLanguage(TargetLanguage):
"""Default SMT-Lib target language implementation."""

def translate_arith_op(self, op: str) -> str:
arith_op_map = {
"+": "+",
"-": "-",
"*": "*",
"/": "div",
"%": "mod",
"^": "^"
}
if op not in arith_op_map:
raise ValueError(f"Unknown arithmetic operator: {op}")
return arith_op_map[op]

def translate_logical_op(self, op: str) -> str:
logical_op_map = {
"==": "=",
">": ">",
"<": "<",
">=": ">=",
"<=": "<=",
"!=": "distinct"
}
if op not in logical_op_map:
raise ValueError(f"Unknown logical operator: {op}")
return logical_op_map[op]

def translate_boolean_op(self, op: str) -> str:
boolean_op_map = {
"And": "and",
"Or": "or",
"Implies": "=>",
"Not": "not"
}
if op not in boolean_op_map:
raise ValueError(f"Unknown boolean operator: {op}")
return boolean_op_map[op]

def translate_quantifier(self, quantifier: str) -> str:
quantifier_map = {
"ForAll": "forall",
"Exists": "exists"
}
if quantifier not in quantifier_map:
raise ValueError(f"Unknown quantifier: {quantifier}")
return quantifier_map[quantifier]

def translate_boolean_literal(self, literal: str) -> str:
literal_map = {
"True": "true",
"False": "false"
}
if literal not in literal_map:
raise ValueError(f"Unknown boolean literal: {literal}")
return literal_map[literal]

def format_function_call(self, function_name: str, args: typing.List[str]) -> str:
if not args:
return function_name
args_str = " ".join(args)
return f"({function_name} {args_str})"

def format_quantified_expression(self, quantifier: str, bound_vars: typing.List[str], body: str) -> str:
# SMT-Lib specific logic for quantified expressions
bound_vars_str = " ".join([f"({var} Int)" for var in bound_vars])

# Add constraints for variables >= 0
constraints = [f"(<= 0 {var})" for var in bound_vars]

if quantifier == "forall":
if constraints:
all_constraints = " ".join(constraints)
body = f"(=> {all_constraints} {body})"
elif quantifier == "exists":
if constraints:
all_constraints = " ".join(constraints)
body = f"(and {all_constraints} {body})"

return f"({quantifier} ({bound_vars_str}) {body})"

def get_preamble(self) -> str:
"""SMT-Lib doesn't need any preamble"""
return ""


class ZnTargetLanguage(TargetLanguage):
"""Z_n finite field target language implementation."""

def __init__(self, n: int):
"""Initialize with modulus n for Z_n finite field"""
if n <= 1:
raise ValueError(f"Modulus n must be greater than 1, got {n}")
self.n = n
self.is_prime = self._is_prime(n)

def _is_prime(self, n: int) -> bool:
"""Check if n is prime"""
if n < 2:
return False
if n == 2:
return True
if n % 2 == 0:
return False
for i in range(3, int(n**0.5) + 1, 2):
if n % i == 0:
return False
return True

def translate_arith_op(self, op: str) -> str:
arith_op_map = {
"+": "zn_add",
"-": "zn_sub",
"*": "zn_mul",
"/": "zn_div" if self.is_prime else "div", # Only support division for prime fields
"%": "mod", # Regular modulo for other uses
"^": "zn_pow" # Modular exponentiation
}
if op not in arith_op_map:
raise ValueError(f"Unknown arithmetic operator: {op}")
return arith_op_map[op]

def translate_logical_op(self, op: str) -> str:
# Logical operations remain the same
logical_op_map = {
"==": "=",
">": ">",
"<": "<",
">=": ">=",
"<=": "<=",
"!=": "distinct"
}
if op not in logical_op_map:
raise ValueError(f"Unknown logical operator: {op}")
return logical_op_map[op]

def translate_boolean_op(self, op: str) -> str:
# Boolean operations remain the same
boolean_op_map = {
"And": "and",
"Or": "or",
"Implies": "=>",
"Not": "not"
}
if op not in boolean_op_map:
raise ValueError(f"Unknown boolean operator: {op}")
return boolean_op_map[op]

def translate_quantifier(self, quantifier: str) -> str:
# Quantifiers remain the same
quantifier_map = {
"ForAll": "forall",
"Exists": "exists"
}
if quantifier not in quantifier_map:
raise ValueError(f"Unknown quantifier: {quantifier}")
return quantifier_map[quantifier]

def translate_boolean_literal(self, literal: str) -> str:
# Boolean literals remain the same
literal_map = {
"True": "true",
"False": "false"
}
if literal not in literal_map:
raise ValueError(f"Unknown boolean literal: {literal}")
return literal_map[literal]

def format_function_call(self, function_name: str, args: typing.List[str]) -> str:
if not args:
return function_name
args_str = " ".join(args)
return f"({function_name} {args_str})"

def format_quantified_expression(self, quantifier: str, bound_vars: typing.List[str], body: str) -> str:
# For Z_n, variables are constrained to [0, n-1]
bound_vars_str = " ".join([f"({var} Int)" for var in bound_vars])

# Add constraints for variables in range [0, n-1]
# Create individual constraints but don't wrap in (and ...) yet
constraints = []
for var in bound_vars:
constraints.append(f"(<= 0 {var})")
constraints.append(f"(< {var} {self.n})")

# Combine constraint with body based on quantifier
if quantifier == "forall":
# For forall: (=> (and constraints...) body)
if len(constraints) == 1:
constraint_expr = constraints[0]
else:
constraint_expr = f"(and {' '.join(constraints)})"
body = f"(=> {constraint_expr} {body})"
elif quantifier == "exists":
# For exists: (and constraints... body)
if len(constraints) == 0:
# No constraints, just body
pass
elif len(constraints) == 1:
body = f"(and {constraints[0]} {body})"
else:
all_parts = constraints + [body]
body = f"(and {' '.join(all_parts)})"

return f"({quantifier} ({bound_vars_str}) {body})"

def _get_basic_operations(self) -> str:
"""Generate basic Z_n operations (add, sub, mul, pow)"""
return f"""; Z_{self.n} finite field operations
(define-fun zn_add ((x Int) (y Int)) Int (mod (+ x y) {self.n}))
(define-fun zn_sub ((x Int) (y Int)) Int (mod (- x y) {self.n}))
(define-fun zn_mul ((x Int) (y Int)) Int (mod (* x y) {self.n}))
(define-fun zn_pow ((x Int) (y Int)) Int (mod (^ x y) {self.n}))
"""

def _get_inverse_function(self) -> str:
"""Generate modular inverse function for prime fields"""
if not self.is_prime:
return ""

inverse_def = "(define-fun zn_inv ((x Int)) Int\n"
inverse_def += " (ite (= x 0) 0\n" # 0 has no inverse

for i in range(1, self.n):
for j in range(1, self.n):
if (i * j) % self.n == 1:
inverse_def += f" (ite (= x {i}) {j}\n"
break

inverse_def += " 0" + ")" * self.n + ")\n"
return inverse_def

def _get_division_function(self) -> str:
"""Generate division function using modular inverse"""
if not self.is_prime:
return f"; Note: Division not supported for non-prime modulus {self.n}\n"

return f"""(define-fun zn_div ((x Int) (y Int)) Int
(ite (= y 0)
0 ; Division by zero returns 0 (undefined behavior)
(mod (* x (zn_inv y)) {self.n})))
"""

def get_preamble(self) -> str:
"""Generate Z_n finite field function definitions"""
parts = [
self._get_basic_operations(),
self._get_inverse_function(), # Must come before division
self._get_division_function()
]
return "".join(part for part in parts if part.strip())
Loading
Loading