Skip to content

Commit 8987f88

Browse files
moved to i64 for rational expression solver because some nat expr were to big
1 parent 19362ea commit 8987f88

File tree

14 files changed

+126
-86
lines changed

14 files changed

+126
-86
lines changed

examples/blocking.rs

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
use std::time::Duration;
22

3-
use egg::{AstSize, RecExpr, Runner, SimpleScheduler};
3+
use egg::{AstSize, Runner, SimpleScheduler};
44

55
use eggshell::eqsat::hooks;
6-
use eggshell::rewrite_system::rise::{self, PrettyPrint, Rise, RiseRuleset};
7-
use eggshell::rewrite_system::rise::{BLOCKING_GOAL, MM, SPLIT_GUIDE};
6+
use eggshell::rewrite_system::rise::{self, BLOCKING_GOAL, MM, RiseRuleset, SPLIT_GUIDE};
87
use eggshell::sketch;
8+
use eggshell::utils;
99

1010
fn main() {
11-
let mm: RecExpr<Rise> = MM.parse().unwrap();
12-
let split_guide: RecExpr<Rise> = SPLIT_GUIDE.parse().unwrap();
11+
let mm = rise::canon_nat(&MM.parse().unwrap());
12+
let split_guide = rise::canon_nat(&SPLIT_GUIDE.parse().unwrap());
13+
let blocking_goal = rise::canon_nat(&BLOCKING_GOAL.parse().unwrap());
1314

1415
let runner_1 = Runner::default()
1516
.with_expr(&mm)
1617
.with_iter_limit(6)
17-
.with_time_limit(Duration::from_secs(30))
18+
.with_time_limit(Duration::from_secs(60))
1819
.with_node_limit(1_000_000)
1920
.with_scheduler(SimpleScheduler)
2021
.with_hook(hooks::targe_hook(split_guide.clone()))
@@ -24,32 +25,43 @@ fn main() {
2425
println!("{}", runner_1.report());
2526

2627
let root_mm = runner_1.egraph.find(runner_1.roots[0]);
27-
let split_guide_sketch = rise::sketchify(SPLIT_GUIDE, true);
28-
let (_, sketch_extracted_split_guide) =
29-
sketch::eclass_extract(&split_guide_sketch, AstSize, &runner_1.egraph, root_mm).unwrap();
30-
31-
println!("\nGuide Ground Truth");
32-
split_guide.pp(false);
33-
println!("\nSketch Extracted:");
34-
sketch_extracted_split_guide.pp(false);
35-
36-
assert_eq!(
37-
None,
38-
eggshell::utils::find_diff(&sketch_extracted_split_guide, &split_guide)
28+
let split_guide_sketch = rise::sketchify(&split_guide, true);
29+
let sketch_extracted_split_guide = rise::canon_nat(
30+
&sketch::eclass_extract(&split_guide_sketch, AstSize, &runner_1.egraph, root_mm)
31+
.unwrap()
32+
.1,
3933
);
40-
assert_eq!(root_mm, runner_1.egraph.lookup_expr(&split_guide).unwrap());
34+
assert!(utils::find_diff(&sketch_extracted_split_guide, &split_guide).is_none());
35+
36+
// println!("\nGuide Ground Truth");
37+
// split_guide.pp(false);
38+
// println!("\nSketch Extracted:");
39+
// sketch_extracted_split_guide.pp(false);
40+
41+
// assert_eq!(root_mm, runner_1.egraph.lookup_expr(&split_guide).unwrap());
4142

42-
let blocking_goal: RecExpr<Rise> = BLOCKING_GOAL.parse().unwrap();
4343
let runner_2 = Runner::default()
4444
.with_expr(&split_guide)
45-
.with_iter_limit(6)
45+
.with_iter_limit(8)
46+
.with_time_limit(Duration::from_secs(60))
47+
.with_node_limit(1_000_000)
4648
.with_scheduler(SimpleScheduler)
4749
.with_hook(hooks::targe_hook(blocking_goal.clone()))
50+
.with_hook(hooks::printer_hook)
4851
.run(&rise::rules(RiseRuleset::MM));
4952

53+
println!("{}", runner_2.report());
54+
5055
let root_guide = runner_2.egraph.find(runner_2.roots[0]);
51-
assert_eq!(
52-
root_guide,
53-
runner_2.egraph.lookup_expr(&blocking_goal).unwrap()
56+
let blocking_goal_sketch = rise::sketchify(&blocking_goal, true);
57+
let sketch_extracted_blocking_goal = rise::canon_nat(
58+
&sketch::eclass_extract(&blocking_goal_sketch, AstSize, &runner_2.egraph, root_guide)
59+
.unwrap()
60+
.1,
5461
);
62+
assert!(utils::find_diff(&sketch_extracted_blocking_goal, &blocking_goal).is_none());
63+
// assert_eq!(
64+
// root_guide,
65+
// runner_2.egraph.lookup_expr(&blocking_goal).unwrap()
66+
// );
5567
}

src/rewrite_system/rise/analysis.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use egg::{Analysis, DidMerge, EGraph, Id, Language, RecExpr};
1+
use egg::{Analysis, DidMerge, EGraph, Language, RecExpr};
22
use hashbrown::HashSet;
33

44
use super::nat::try_simplify;
@@ -47,28 +47,29 @@ impl Analysis<Rise> for RiseAnalysis {
4747
fn merge(&mut self, to: &mut AnalysisData, from: AnalysisData) -> DidMerge {
4848
let before_len = to.free.len();
4949
to.free.extend(from.free);
50-
let free_merge = DidMerge(before_len != to.free.len(), true);
50+
let free_merge = before_len != to.free.len();
5151

5252
let beta_merge = if !from.beta_extract.is_empty()
5353
&& (to.beta_extract.is_empty() || to.beta_extract.len() > from.beta_extract.len())
5454
{
5555
to.beta_extract = from.beta_extract;
56-
DidMerge(true, true)
56+
true
5757
} else {
58-
DidMerge(false, true) // TODO: more precise second bool
58+
false
5959
};
6060

6161
// let nat_merge =
6262
// egg::merge_option(&mut to.simple_nat, from.simple_nat, |to_nat, from_nat| {
6363
// if to_nat.len() > from_nat.len() {
6464
// *to_nat = from_nat;
65-
// DidMerge(true, true)
65+
// true
6666
// } else {
67-
// DidMerge(false, true)
67+
// false
6868
// }
6969
// });
7070

71-
free_merge | beta_merge // | nat_merge
71+
// TODO: more precise second bool
72+
DidMerge(free_merge || beta_merge, true)
7273
}
7374

7475
fn make(egraph: &mut EGraph<Rise, RiseAnalysis>, enode: &Rise) -> AnalysisData {
@@ -100,7 +101,8 @@ impl Analysis<Rise> for RiseAnalysis {
100101
(RecExpr::default(), None)
101102
} else {
102103
let expr = enode.join_recexprs(|id| egraph[id].data.beta_extract.as_ref());
103-
let simple_nat = try_simplify(&expr).ok();
104+
let simple_nat = enode.is_nat().then(|| try_simplify(&expr).ok()).flatten();
105+
104106
(expr, simple_nat)
105107
};
106108

src/rewrite_system/rise/func.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ impl<A: Applier<Rise, RiseAnalysis>> Applier<Rise, RiseAnalysis> for VectorizeSc
9999
}
100100
}
101101

102-
fn extracted_int(expr: &RecExpr<Rise>) -> i32 {
102+
fn extracted_int(expr: &RecExpr<Rise>) -> i64 {
103103
if let Rise::Integer(i) = expr[0.into()] {
104104
return i;
105105
}
@@ -110,7 +110,7 @@ fn extracted_int(expr: &RecExpr<Rise>) -> i32 {
110110
#[expect(clippy::too_many_lines)]
111111
fn vec_expr(
112112
expr: &RecExpr<Rise>,
113-
n: i32,
113+
n: i64,
114114
v_env: HashSet<DBIndex>,
115115
type_of_id: Id,
116116
) -> Option<(RecExpr<Rise>, Id, Id)> {
@@ -249,7 +249,7 @@ fn vec_expr(
249249
// }
250250
}
251251

252-
fn vec_ty(expr: &RecExpr<Rise>, n: i32, id: Id) -> Option<RecExpr<Rise>> {
252+
fn vec_ty(expr: &RecExpr<Rise>, n: i64, id: Id) -> Option<RecExpr<Rise>> {
253253
match expr[id] {
254254
Rise::F32 => {
255255
let mut vec_ty = RecExpr::default();

src/rewrite_system/rise/lang.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ egg::define_language! {
7777
// "sig" = Sigma([Id; 3]),
7878
// "phi" = Phi([Id; 3]),
7979

80-
Integer(i32),
81-
Float(NotNan<f32>),
80+
Integer(i64),
81+
Float(NotNan<f64>),
8282
// Double(f64),
8383
// Symbol(Symbol),
8484
}

src/rewrite_system/rise/mod.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ pub use analysis::RiseAnalysis;
1717
pub use lang::Rise;
1818
pub use pp::PrettyPrint;
1919

20-
use crate::sketch::{Sketch, SketchLang};
20+
use crate::{
21+
rewrite_system::rise::nat::try_simplify,
22+
sketch::{Sketch, SketchLang},
23+
};
2124

2225
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
2326
pub enum RiseRuleset {
@@ -45,9 +48,8 @@ fn add_expr(to: &mut RecExpr<Rise>, e: RecExpr<Rise>) -> Id {
4548
to.root()
4649
}
4750

48-
#[expect(clippy::missing_panics_doc)]
4951
#[must_use]
50-
pub fn sketchify(term: &str, sketchify_nat_expr: bool) -> Sketch<Rise> {
52+
pub fn sketchify(expr: &RecExpr<Rise>, sketchify_nat_expr: bool) -> Sketch<Rise> {
5153
fn rec(
5254
expr: &RecExpr<Rise>,
5355
id: Id,
@@ -67,12 +69,29 @@ pub fn sketchify(term: &str, sketchify_nat_expr: bool) -> Sketch<Rise> {
6769
}
6870
}
6971
}
70-
let expr: RecExpr<Rise> = term.parse().unwrap();
7172
let mut sketch = RecExpr::default();
72-
rec(&expr, expr.root(), sketchify_nat_expr, &mut sketch);
73+
rec(expr, expr.root(), sketchify_nat_expr, &mut sketch);
7374
sketch
7475
}
7576

77+
#[must_use]
78+
pub fn canon_nat(expr: &RecExpr<Rise>) -> RecExpr<Rise> {
79+
fn rec(expr: &RecExpr<Rise>, id: Id, canon_expr: &mut RecExpr<Rise>) -> Id {
80+
let node = &expr[id];
81+
if let Ok(canon_nat_expr) = try_simplify(&node.build_recexpr(|i| expr[i].clone())) {
82+
add_expr(canon_expr, canon_nat_expr)
83+
} else {
84+
let new_node = node
85+
.clone()
86+
.map_children(|c_id| rec(expr, c_id, canon_expr));
87+
canon_expr.add(new_node)
88+
}
89+
}
90+
let mut canon_expr = RecExpr::default();
91+
rec(expr, expr.root(), &mut canon_expr);
92+
canon_expr
93+
}
94+
7695
// START TERM
7796
pub const MM: &str = "(typeOf (natLam (typeOf (natLam (typeOf (natLam (typeOf (lam (typeOf (lam (typeOf (app (typeOf (app (typeOf map (fun (fun (arrT %n0 f32) (arrT %n1 f32)) (fun (arrT %n2 (arrT %n0 f32)) (arrT %n2 (arrT %n1 f32))))) (typeOf (lam (typeOf (app (typeOf (app (typeOf map (fun (fun (arrT %n0 f32) f32) (fun (arrT %n1 (arrT %n0 f32)) (arrT %n1 f32)))) (typeOf (lam (typeOf (app (typeOf (app (typeOf (app (typeOf reduce (fun (fun f32 (fun f32 f32)) (fun f32 (fun (arrT %n0 f32) f32)))) (typeOf add (fun f32 (fun f32 f32)))) (fun f32 (fun (arrT %n0 f32) f32))) (typeOf 0.0 f32)) (fun (arrT %n0 f32) f32)) (typeOf (app (typeOf (app (typeOf map (fun (fun (pairT f32 f32) f32) (fun (arrT %n0 (pairT f32 f32)) (arrT %n0 f32)))) (typeOf (lam (typeOf (app (typeOf (app (typeOf mul (fun f32 (fun f32 f32))) (typeOf (app (typeOf fst (fun (pairT f32 f32) f32)) (typeOf %e0 (pairT f32 f32))) f32)) (fun f32 f32)) (typeOf (app (typeOf snd (fun (pairT f32 f32) f32)) (typeOf %e0 (pairT f32 f32))) f32)) f32)) (fun (pairT f32 f32) f32))) (fun (arrT %n0 (pairT f32 f32)) (arrT %n0 f32))) (typeOf (app (typeOf (app (typeOf zip (fun (arrT %n0 f32) (fun (arrT %n0 f32) (arrT %n0 (pairT f32 f32))))) (typeOf %e1 (arrT %n0 f32))) (fun (arrT %n0 f32) (arrT %n0 (pairT f32 f32)))) (typeOf %e0 (arrT %n0 f32))) (arrT %n0 (pairT f32 f32)))) (arrT %n0 f32))) f32)) (fun (arrT %n0 f32) f32))) (fun (arrT %n1 (arrT %n0 f32)) (arrT %n1 f32))) (typeOf (app (typeOf transpose (fun (arrT %n0 (arrT %n1 f32)) (arrT %n1 (arrT %n0 f32)))) (typeOf %e1 (arrT %n0 (arrT %n1 f32)))) (arrT %n1 (arrT %n0 f32)))) (arrT %n1 f32))) (fun (arrT %n0 f32) (arrT %n1 f32)))) (fun (arrT %n2 (arrT %n0 f32)) (arrT %n2 (arrT %n1 f32)))) (typeOf %e1 (arrT %n2 (arrT %n0 f32)))) (arrT %n2 (arrT %n1 f32)))) (fun (arrT %n0 (arrT %n1 f32)) (arrT %n2 (arrT %n1 f32))))) (fun (arrT %n2 (arrT %n0 f32)) (fun (arrT %n0 (arrT %n1 f32)) (arrT %n2 (arrT %n1 f32)))))) (natFun (fun (arrT %n2 (arrT %n0 f32)) (fun (arrT %n0 (arrT %n1 f32)) (arrT %n2 (arrT %n1 f32))))))) (natFun (natFun (fun (arrT %n2 (arrT %n0 f32)) (fun (arrT %n0 (arrT %n1 f32)) (arrT %n2 (arrT %n1 f32)))))))) (natFun (natFun (natFun (fun (arrT %n2 (arrT %n0 f32)) (fun (arrT %n0 (arrT %n1 f32)) (arrT %n2 (arrT %n1 f32))))))))";
7897
// GUIDES

src/rewrite_system/rise/nat/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ use polynomial::Polynomial;
1717
pub use applier::{ComputeNat, ComputeNatCheck};
1818
pub use rational::RationalFunction;
1919

20+
type Ratio = num::rational::Ratio<i64>;
21+
2022
// ============================================================================
2123
// Helper Functions
2224
// ============================================================================

src/rewrite_system/rise/nat/monomial.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@ use crate::rewrite_system::rise::DBIndex;
1010
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
1111
pub struct Monomial {
1212
// Map from dbindex to exponent (can be negative for rational expressions)
13-
variables: BTreeMap<DBIndex, i32>,
13+
variables: BTreeMap<DBIndex, i64>,
1414
}
1515

1616
impl Monomial {
1717
pub fn new() -> Self {
1818
Self::default()
1919
}
2020

21-
pub fn variables(&self) -> &BTreeMap<DBIndex, i32> {
21+
pub fn variables(&self) -> &BTreeMap<DBIndex, i64> {
2222
&self.variables
2323
}
2424

25-
pub fn with_var(mut self, dbindex: DBIndex, exponent: i32) -> Self {
25+
pub fn with_var(mut self, dbindex: DBIndex, exponent: i64) -> Self {
2626
if exponent != 0 {
2727
self.variables.insert(dbindex, exponent);
2828
}

src/rewrite_system/rise/nat/polynomial/from.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
use egg::RecExpr;
2-
use num::rational::Ratio;
32
use num_traits::One;
43

5-
use super::{Monomial, Polynomial, Rise};
4+
use super::{Monomial, Polynomial, Ratio, Rise};
65
use crate::rewrite_system::rise::DBIndex;
76

87
// ============================================================================
@@ -64,14 +63,14 @@ impl From<Polynomial> for RecExpr<Rise> {
6463
// ============================================================================
6564

6665
/// Create a `Polynomial` from an integer constant
67-
impl From<i32> for Polynomial {
68-
fn from(n: i32) -> Self {
66+
impl From<i64> for Polynomial {
67+
fn from(n: i64) -> Self {
6968
Self::new().add_term(n.into(), Monomial::new())
7069
}
7170
}
7271
/// Create a `Polynomial` from an integer constant
73-
impl From<Ratio<i32>> for Polynomial {
74-
fn from(r: Ratio<i32>) -> Self {
72+
impl From<Ratio> for Polynomial {
73+
fn from(r: Ratio) -> Self {
7574
Self::new().add_term(r, Monomial::new())
7675
}
7776
}

0 commit comments

Comments
 (0)