Skip to content

Commit 617091b

Browse files
committed
fixed issues with different types
1 parent f2c880b commit 617091b

File tree

5 files changed

+54
-58
lines changed

5 files changed

+54
-58
lines changed

grape/automaton_generator.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,16 @@ def transition(p: Program, args: tuple[int, ...]) -> int | None:
6060
)
6161

6262

63-
def commutativity_constraint(
64-
dsl: DSL, commutatives: list[tuple[str, list[int]]], type_req: str
65-
) -> Constraint:
66-
arg_types, target_type = types.parse(type_req)
63+
def commutativity_constraint(commutatives: list[tuple[str, list[int]]]) -> Constraint:
6764
to_check: dict[str, list[tuple[int, int]]] = {}
6865
for p, swapped in commutatives:
6966
if p not in to_check:
7067
to_check[p] = []
7168
to_check[p].append((swapped[0], swapped[1]))
7269

7370
def transition(p: Program, args: tuple[str, ...]) -> str | None:
74-
if isinstance(p, Function):
75-
key = str(p.function)
71+
if isinstance(p, Primitive):
72+
key = str(p)
7673
if key in to_check:
7774
if all(args[i] <= args[j] for i, j in to_check[key]):
7875
return key
@@ -83,17 +80,10 @@ def transition(p: Program, args: tuple[str, ...]) -> str | None:
8380
else:
8481
return str(p)
8582

86-
whatever = target_type.lower() == "none"
87-
8883
return Constraint(
8984
str,
9085
transition,
91-
lambda s: whatever
92-
or (
93-
types.return_type(dsl.primitives[s][0]) == target_type
94-
if s in dsl.primitives
95-
else arg_types[int(s[len("var") :])]
96-
),
86+
lambda _: True,
9787
)
9888

9989

@@ -187,17 +177,17 @@ def grammar_from_memory(
187177
type_req: str,
188178
prev_finals: set[str],
189179
) -> tuple[DFTA[str, Program], int]:
180+
"""
181+
Returns (specialized grammar, 1)
182+
"""
190183
max_size = max(max(memory[state].keys()) for state in memory)
191184
args_type = types.arguments(type_req)
192185
# Compute variable merging: all variables of same type should be merged
193-
var_merge = {}
194186
type2var = {}
195-
for i, t in enumerate(args_type):
196-
if t in type2var:
197-
var_merge[i] = type2var[t]
198-
else:
199-
type2var[t] = i
200-
var_merge[i] = i
187+
for t in args_type:
188+
if t not in type2var:
189+
type2var[t] = len(type2var)
190+
var_merge = {i: type2var[t] for i, t in enumerate(args_type)}
201191
# Produce rules incrementally
202192
rules: dict[tuple[Program, tuple[str, ...]], str] = {}
203193
finals: set[str] = set()
@@ -223,22 +213,22 @@ def grammar_from_memory(
223213

224214
# Reproduce original type request to compare number of programs
225215
# add a rule for each deleted variable
226-
added = set()
227-
for i, j in var_merge.items():
228-
if i == j:
229-
continue
230-
# data: variable i is renamed as variable j
231-
old = Variable(i)
232-
for (prog, _), dst in relevant_dfta.rules.copy().items():
233-
if isinstance(prog, Variable) and prog.no == j:
234-
relevant_dfta.rules[(old, ())] = dst
235-
added.add((old, ()))
236-
relevant_dfta.refresh_reversed_rules()
237-
n = relevant_dfta.trees_until_size(max_size)
238-
# Delete them now that they have been used
239-
for x in added:
240-
del relevant_dfta.rules[x]
241-
relevant_dfta.refresh_reversed_rules()
216+
# added = set()
217+
# for i, j in var_merge.items():
218+
# if i == j:
219+
# continue
220+
# # data: variable i is renamed as variable j
221+
# old = Variable(i)
222+
# for (prog, _), dst in relevant_dfta.rules.copy().items():
223+
# if isinstance(prog, Variable) and prog.no == j:
224+
# relevant_dfta.rules[(old, ())] = dst
225+
# added.add((old, ()))
226+
# relevant_dfta.refresh_reversed_rules()
227+
# n = relevant_dfta.trees_until_size(max_size)
228+
# # Delete them now that they have been used
229+
# for x in added:
230+
# del relevant_dfta.rules[x]
231+
# relevant_dfta.refresh_reversed_rules()
242232
# ==================================
243233

244-
return relevant_dfta, n
234+
return relevant_dfta, 1

grape/cli/prune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def main():
118118
evaluator,
119119
manager,
120120
args.size,
121-
target_type,
121+
None,
122122
base_grammar,
123123
)
124124
type_req = type_request_from_specialized(reduced_grammar, dsl)

grape/pruning/obs_equiv_pruner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __get_base_grammar__(
7373
grammar = grammar_by_saturation(
7474
dsl,
7575
type_req,
76-
[commutativity_constraint(dsl, commutatives, type_req)],
76+
[commutativity_constraint(commutatives)],
7777
)
7878
else:
7979
base_grammar = base_dfta
@@ -116,6 +116,8 @@ def prune(
116116
base_grammar,
117117
type_req,
118118
)
119+
old_finals = grammar.finals.copy()
120+
grammar.finals = set(grammar.all_states)
119121
enum_ntrees = grammar.trees_until_size(max_size)
120122
base_ntrees = sum(base_expected_trees.values())
121123

@@ -170,9 +172,9 @@ def estimate_total(size: int) -> tuple[int, float]:
170172
pbar.update(n)
171173
pbar.close()
172174
evaluator.free_memory()
173-
reduced_grammar, t = grammar_from_memory(
174-
enumerator.memory, type_req, grammar.finals
175-
)
175+
grammar.finals = old_finals
176+
reduced_grammar, t = grammar_from_memory(enumerator.memory, type_req, old_finals)
177+
t = reduced_grammar.trees_until_size(max_size)
176178
print(f"at size {max_size} programs (after graping): {t:.2e}")
177179
print(
178180
"\tmethod: ratio no graping | ratio base | ratio graped",

tests/automaton/test_loop_manager.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def signed_masking(n: int, mask: int = MAXI) -> int:
4747
return n & mask if n > 0 else -(-n & mask)
4848

4949

50-
sample_dict = {"int": lambda: signed_masking(random.randint(-MAXI, MAXI), MAXI)}
50+
sample_dict = {
51+
"int": lambda: signed_masking(random.randint(-MAXI, MAXI), MAXI),
52+
"bool": lambda: random.uniform(0, 1) > 0.5,
53+
}
5154

5255
inputs = sample_inputs(50, sample_dict)
5356

@@ -71,8 +74,7 @@ def signed_masking(n: int, mask: int = MAXI) -> int:
7174
max_size = 5
7275
manager = EquivalenceClassManager()
7376
evaluator = Evaluator(dsl, inputs, {}, set())
74-
out = prune(dsl, evaluator, manager, max_size=max_size, rtype="int")
75-
tr = "int->int"
77+
tr = "int->none"
7678
saturated = grammar_by_saturation(dsl, tr)
7779

7880

@@ -118,6 +120,7 @@ def comp_by_enum(grammars: list, tr: str, max_size: int):
118120

119121
@pytest.mark.parametrize("algo", algorithms)
120122
def test_same_size(algo: LoopingAlgorithm):
123+
out = prune(dsl, evaluator, manager, max_size=max_size)
121124
new_out = add_loops(out, dsl, algo)
122125
spec_out = respecialize(
123126
new_out, tr, type_request_from_specialized(new_out, dsl), dsl
@@ -127,6 +130,7 @@ def test_same_size(algo: LoopingAlgorithm):
127130

128131
@pytest.mark.parametrize("algo", algorithms)
129132
def test_next_size(algo: LoopingAlgorithm):
133+
out = prune(dsl, evaluator, manager, max_size=max_size)
130134
new_out = add_loops(out, dsl, algo)
131135
spec_out = respecialize(
132136
new_out, tr, type_request_from_specialized(new_out, dsl), dsl

tests/cli/test_prune.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def signed_masking(n: int, mask: int = MAXI) -> int:
6868
),
6969
}
7070
)
71-
evaluator = Evaluator(dsl, inputs, {}, set())
7271
max_size = 5
7372
algorithms = [LoopingAlgorithm.OBSERVATIONAL_EQUIVALENCE, LoopingAlgorithm.GRAPE]
7473

@@ -121,26 +120,27 @@ def comp_by_enum(grammars: list, tr: str, max_size: int):
121120

122121
def test_prune():
123122
manager = EquivalenceClassManager()
124-
out = prune(dsl, evaluator, manager, max_size=max_size, rtype="int")
125-
tr = "int->int"
123+
evaluator = Evaluator(dsl, inputs, {}, set())
124+
out = prune(dsl, evaluator, manager, max_size=max_size)
125+
tr = "int->none"
126126
g = grammar_by_saturation(dsl, tr)
127127
spec_out = respecialize(out, tr, type_request_from_specialized(out, dsl), dsl)
128128
comp_by_enum([spec_out, g], tr, max_size)
129129

130130

131131
def test_incremental_same_size():
132132
manager = EquivalenceClassManager()
133-
out = prune(dsl, evaluator, manager, max_size=max_size, rtype="int")
134-
incremental = prune(
135-
dsl, evaluator, manager, max_size=max_size, rtype="int", base_grammar=out
136-
)
133+
evaluator = Evaluator(dsl, inputs, {}, set())
134+
out = prune(dsl, evaluator, manager, max_size=max_size)
135+
incremental = prune(dsl, evaluator, manager, max_size=max_size, base_grammar=out)
137136
assert out.rules == incremental.rules
138137
assert out.finals == incremental.finals
139138

140139

141140
@pytest.mark.parametrize("algo", algorithms)
142141
def test_incremental_same_size_with_loops(algo: LoopingAlgorithm):
143142
manager = EquivalenceClassManager()
143+
evaluator = Evaluator(dsl, inputs, {}, set())
144144
out = prune(dsl, evaluator, manager, max_size=max_size, rtype="int")
145145
out = add_loops(out, dsl, algo)
146146

@@ -157,14 +157,14 @@ def test_incremental_same_size_with_loops(algo: LoopingAlgorithm):
157157
def test_incremental_next_size(algo: LoopingAlgorithm):
158158
manager = EquivalenceClassManager()
159159
evaluator = Evaluator(dsl, inputs, {}, set())
160-
out = prune(dsl, evaluator, manager, max_size=max_size, rtype="int")
160+
out = prune(dsl, evaluator, manager, max_size=max_size)
161161
out = add_loops(out, dsl, algo)
162162
evaluator.free_memory()
163163
incremental = prune(
164-
dsl, evaluator, manager, max_size=max_size + 1, rtype="int", base_grammar=out
164+
dsl, evaluator, manager, max_size=max_size + 1, base_grammar=out
165165
)
166166
evaluator.free_memory()
167-
direct = prune(dsl, evaluator, manager, max_size=max_size + 1, rtype="int")
167+
direct = prune(dsl, evaluator, manager, max_size=max_size + 1)
168168
comp_by_enum(
169169
[incremental, direct], type_request_from_specialized(direct, dsl), max_size + 1
170170
)
@@ -174,8 +174,8 @@ def test_is_superset():
174174
evaluator = Evaluator(dsl, inputs, {}, set())
175175

176176
manager = EquivalenceClassManager()
177-
out = prune(dsl, evaluator, manager, max_size=max_size, rtype="int")
178-
tr = "int->int"
177+
out = prune(dsl, evaluator, manager, max_size=max_size)
178+
tr = "int->none"
179179
base = grammar_by_saturation(dsl, tr)
180180
evaluator = Evaluator(dsl, inputs, {}, set())
181181
ebase = Enumerator(base)

0 commit comments

Comments
 (0)