Skip to content

Commit bb185fb

Browse files
committed
Fix equality check
1 parent d834818 commit bb185fb

File tree

1 file changed

+78
-10
lines changed

1 file changed

+78
-10
lines changed

src/typechecking.pr

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -464,37 +464,67 @@ export def copy(pattern: &Pattern) -> &Pattern {
464464
return res
465465
}
466466

467+
type EqEnv = struct {
468+
scope: &scope::Scope
469+
typedefs_a: &SMap(int)
470+
typedefs_b: &SMap(int)
471+
counter: int
472+
}
473+
474+
def make_eq_env(scpe: &scope::Scope) -> &EqEnv {
475+
return [
476+
scope = scpe,
477+
typedefs_a = map::make(int),
478+
typedefs_b = map::make(int),
479+
counter = 0
480+
] !&EqEnv
481+
}
482+
467483
export def == (this: &Pattern, other: &Pattern) -> bool {
484+
return equals(this, other)
485+
}
486+
487+
export def equals(this: &Pattern, other: &Pattern, env: &EqEnv = null) -> bool {
468488
if this.kind != other.kind {
469489
return false
470490
}
471491
switch this.kind {
472492
case TypeKind::TYPE_DEF
473-
if this.tpe and other.tpe {
493+
if env {
494+
env.typedefs_a(this.name) = env.counter
495+
env.typedefs_b(other.name) = env.counter
496+
env.counter += 1
497+
if this.tpe and other.tpe {
498+
return parameter_equals_single(this.tpe, other.tpe, env)
499+
}
500+
return true
501+
} else if this.tpe and other.tpe {
474502
return this.tpe == other.tpe
475503
} else if this.name and other.name {
476-
return true
504+
return this.name == other.name
477505
} else {
478-
assert
506+
return false
479507
}
480508
case TypeKind::ARRAY, TypeKind::POINTER, TypeKind::REFERENCE, TypeKind::WEAK_REF
481-
return this.kw == other.kw and this.tpe == other.tpe
509+
return this.kw == other.kw and parameter_equals_single(this.tpe, other.tpe, env)
482510
case TypeKind::STATIC_ARRAY
483-
return this.length == other.length and this.kw == other.kw and this.tpe == other.tpe
511+
return this.length == other.length and this.kw == other.kw and parameter_equals_single(this.tpe, other.tpe, env)
484512
case TypeKind::FUNCTION, TypeKind::CLOSURE
485513
if this.parameter_t.length() != other.parameter_t.length() {
486514
return false
487515
}
488516
for var i in 0..this.parameter_t.length() {
489-
if this.parameter_t(i) != other.parameter_t(i) {
517+
let param_a = this.parameter_t(i)
518+
let param_b = other.parameter_t(i)
519+
if param_a.name != param_b.name or not parameter_equals_single(param_a.tpe, param_b.tpe, env) {
490520
return false
491521
}
492522
}
493523
if this.return_t.length() != other.return_t.length() {
494524
return false
495525
}
496526
for var i in 0..this.return_t.length() {
497-
if this.return_t(i) != other.return_t(i) {
527+
if not parameter_equals_single(this.return_t(i), other.return_t(i), env) {
498528
return false
499529
}
500530
}
@@ -504,7 +534,7 @@ export def == (this: &Pattern, other: &Pattern) -> bool {
504534
return false
505535
}
506536
for var i in 0..this.return_t.length() {
507-
if this.return_t(i) != other.return_t(i) {
537+
if not parameter_equals_single(this.return_t(i), other.return_t(i), env) {
508538
return false
509539
}
510540
}
@@ -514,7 +544,7 @@ export def == (this: &Pattern, other: &Pattern) -> bool {
514544
return false
515545
}
516546
for var i in 0..this.fields.length() {
517-
if this.fields(i).name != other.fields(i).name or this.fields(i).tpe != other.fields(i).tpe {
547+
if this.fields(i).name != other.fields(i).name or not parameter_equals_single(this.fields(i).tpe, other.fields(i).tpe, env) {
518548
return false
519549
}
520550
}
@@ -3048,16 +3078,54 @@ export def parameter_equals(
30483078
if a.parameter_t.length != vector::length(param_b) {
30493079
return false
30503080
}
3081+
let env = make_eq_env(scpe)
30513082
for var i in 0..vector::length(param_b) {
30523083
let left = param_b(i)
30533084
let right = a.parameter_t(i)
3054-
if left != right {
3085+
let e = parameter_equals_single(left.tpe, right.tpe, env)
3086+
//print(generic_to_string(left.tpe), " == ", generic_to_string(right.tpe), " = ", e, "\n")
3087+
if not e {
30553088
return false
30563089
}
30573090
}
30583091
return true
30593092
}
30603093

3094+
def parameter_equals_single(a: &TypeRef, b: &TypeRef, env: &EqEnv) -> bool {
3095+
if not env { return a == b }
3096+
if is_type(a) and is_type(b) {
3097+
return a == b
3098+
} else if is_type(a) {
3099+
let tb = b.copy().resolve(env.scope, error_on_fail = false)
3100+
return tb == a.tpe
3101+
} else if is_type(b) {
3102+
let ta = a.copy().resolve(env.scope, error_on_fail = false)
3103+
return ta == b.tpe
3104+
} else if a.name and b.name {
3105+
if env.typedefs_a.contains(a.name) and env.typedefs_b.contains(b.name) {
3106+
return env.typedefs_a(a.name) == env.typedefs_b(b.name)
3107+
} else if env.typedefs_a.contains(a.name) or env.typedefs_b.contains(b.name) {
3108+
return false
3109+
} else {
3110+
return a.name == b.name
3111+
}
3112+
} else if a.name {
3113+
let ta_r = a.copy().resolve(env.scope, error_on_fail = false)
3114+
if ta_r { return ta_r == b.copy().resolve(env.scope, error_on_fail = false) }
3115+
return false
3116+
} else if b.name {
3117+
let tb_r = b.copy().resolve(env.scope, error_on_fail = false)
3118+
if tb_r { return tb_r == a.copy().resolve(env.scope, error_on_fail = false) }
3119+
return false
3120+
} else {
3121+
let pa = a.pattern
3122+
let pb = b.pattern
3123+
if not pa and not pb { return true }
3124+
if not pa or not pb { return false }
3125+
return equals(pa, pb, env)
3126+
}
3127+
}
3128+
30613129
export def overload_score(
30623130
a: FunctionDef, param_b: &Vector(NamedParameter), scpe: &scope::Scope,
30633131
positional: bool, partial: bool = false, impl: bool = true) -> int {

0 commit comments

Comments
 (0)