Skip to content

Commit 668fcd0

Browse files
committed
Support builtins during type inference
1 parent fd7e99c commit 668fcd0

File tree

4 files changed

+143
-10
lines changed

4 files changed

+143
-10
lines changed

eval/builtins.go

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@ package eval
33
import (
44
"fmt"
55
"math"
6+
"strings"
67

78
"github.com/Victorystick/scrapscript/types"
89
)
910

10-
func bindBuiltIns(reg *types.Registry) Variables {
11+
func bindBuiltIns(reg *types.Registry) (types.TypeScope, Variables) {
12+
var scope types.TypeScope
1113
var builtIns = make(Variables)
1214

1315
define := func(name string, typ types.TypeRef, val Func) {
1416
builtIns[name] = BuiltInFunc{name, typ, val}
17+
scope = scope.Bind(name, typ)
1518
}
1619

1720
// Built-in types
@@ -27,8 +30,16 @@ func bindBuiltIns(reg *types.Registry) Variables {
2730
aToB := reg.Func(a, b)
2831
aList := reg.List(a)
2932
bList := reg.List(b)
33+
textList := reg.List(types.TextRef)
3034

3135
// Lists
36+
define("list/length", reg.Func(aList, types.IntRef), func(val Value) (Value, error) {
37+
ls, ok := val.(List)
38+
if !ok {
39+
return nil, fmt.Errorf("expected list, but got %T", val)
40+
}
41+
return Int(len(ls.elements)), nil
42+
})
3243
define("list/map", reg.Func(aToB, reg.Func(aList, bList)), func(val Value) (Value, error) {
3344
fn := Callable(val)
3445
if fn == nil {
@@ -56,8 +67,8 @@ func bindBuiltIns(reg *types.Registry) Variables {
5667
},
5768
}, nil
5869
})
59-
// TODO: type
60-
define("list/fold", types.NeverRef, func(acc Value) (Value, error) {
70+
accum := reg.Func(a, reg.Func(b, a))
71+
define("list/fold", reg.Func(a, reg.Func(accum, reg.Func(bList, a))), func(acc Value) (Value, error) {
6172
source := "list/fold " + acc.String()
6273
return ScriptFunc{
6374
source: source,
@@ -96,6 +107,22 @@ func bindBuiltIns(reg *types.Registry) Variables {
96107
},
97108
}, nil
98109
})
110+
define("list/repeat", reg.Func(types.IntRef, reg.Func(a, aList)), func(val Value) (Value, error) {
111+
n, ok := val.(Int)
112+
if !ok {
113+
return nil, fmt.Errorf("expected int, but got %T", val)
114+
}
115+
return ScriptFunc{
116+
source: "list/repeat " + val.String(),
117+
fn: func(val Value) (v Value, err error) {
118+
elems := make([]Value, int(n))
119+
for i := range elems {
120+
elems[i] = val
121+
}
122+
return List{val.Type(), elems}, nil
123+
},
124+
}, nil
125+
})
99126

100127
// Text
101128
define("text/length", reg.Func(types.TextRef, types.IntRef), func(val Value) (Value, error) {
@@ -105,6 +132,46 @@ func bindBuiltIns(reg *types.Registry) Variables {
105132
}
106133
return Int(len(text)), nil
107134
})
135+
define("text/repeat", reg.Func(types.IntRef, reg.Func(types.TextRef, types.TextRef)), func(val Value) (Value, error) {
136+
n, ok := val.(Int)
137+
if !ok {
138+
return nil, fmt.Errorf("expected int, but got %T", val)
139+
}
140+
return ScriptFunc{
141+
source: "text/repeat " + val.String(),
142+
fn: func(val Value) (v Value, err error) {
143+
text, ok := val.(Text)
144+
if !ok {
145+
return nil, fmt.Errorf("expected text, but got %T", val)
146+
}
147+
return Text(strings.Repeat(string(text), int(n))), nil
148+
},
149+
}, nil
150+
})
151+
define("text/join", reg.Func(types.TextRef, reg.Func(textList, types.TextRef)), func(val Value) (Value, error) {
152+
sep, ok := val.(Text)
153+
if !ok {
154+
return nil, fmt.Errorf("expected text, but got %T", val)
155+
}
156+
return ScriptFunc{
157+
source: "text/join " + val.String(),
158+
fn: func(val Value) (v Value, err error) {
159+
ls, ok := val.(List)
160+
if !ok {
161+
return nil, fmt.Errorf("expected list, but got %T", val)
162+
}
163+
elems := make([]string, len(ls.elements))
164+
for i, v := range ls.elements {
165+
text, ok := v.(Text)
166+
if !ok {
167+
return nil, fmt.Errorf("expected text, but got %T", v)
168+
}
169+
elems[i] = string(text)
170+
}
171+
return Text(strings.Join(elems, string(sep))), nil
172+
},
173+
}, nil
174+
})
108175

109176
// int -> float
110177
define("to-float", reg.Func(types.IntRef, types.FloatRef), func(val Value) (Value, error) {
@@ -134,7 +201,7 @@ func bindBuiltIns(reg *types.Registry) Variables {
134201
return nil, fmt.Errorf("cannot bytes/from-utf8-text on %T", val)
135202
})
136203

137-
return builtIns
204+
return scope, builtIns
138205
}
139206

140207
func roundFunc(round func(float64) float64) Func {

eval/env.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ type Scrap struct {
2020
type Sha256Hash = [32]byte
2121

2222
type Environment struct {
23-
fetcher yards.Fetcher
24-
reg types.Registry
23+
fetcher yards.Fetcher
24+
reg types.Registry
25+
// The TypeScope and Variables match each other's contents.
26+
// One is used for type inference, the other for evaluation.
27+
typeScope types.TypeScope
2528
vars Variables
2629
scraps map[Sha256Hash]*Scrap
2730
evalImport EvalImport
@@ -30,7 +33,9 @@ type Environment struct {
3033

3134
func NewEnvironment() *Environment {
3235
env := &Environment{}
33-
env.vars = bindBuiltIns(&env.reg)
36+
typeScope, vars := bindBuiltIns(&env.reg)
37+
env.typeScope = typeScope
38+
env.vars = vars
3439
env.scraps = make(map[Sha256Hash]*Scrap)
3540
env.evalImport = func(algo string, hash []byte) (Value, error) {
3641
scrap, err := env.fetch(algo, hash)
@@ -104,9 +109,7 @@ func (e *Environment) Eval(scrap *Scrap) (Value, error) {
104109

105110
func (e *Environment) infer(scrap *Scrap) (types.TypeRef, error) {
106111
if scrap.typ == types.NeverRef {
107-
// TODO: Add a complete type scope.
108-
scope := types.DefaultScope(&e.reg)
109-
ref, err := types.Infer(&e.reg, scope, scrap.expr, e.inferImport)
112+
ref, err := types.Infer(&e.reg, e.typeScope, scrap.expr, e.inferImport)
110113
scrap.typ = ref
111114
return ref, err
112115
}

eval/env_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package eval
2+
3+
import "testing"
4+
5+
func TestInferBuiltin(t *testing.T) {
6+
examples := []struct {
7+
source string
8+
result string
9+
}{
10+
// numeric conversions
11+
{`round`, `float -> int`},
12+
{`ceil`, `float -> int`},
13+
{`floor`, `float -> int`},
14+
{`to-float`, `int -> float`},
15+
16+
// byte <-> text conversion
17+
{`bytes/to-utf8-text`, `bytes -> text`},
18+
{`bytes/from-utf8-text`, `text -> bytes`},
19+
20+
// list
21+
{`list/length`, `list $0 -> int`},
22+
{`list/map`, `($0 -> $1) -> list $0 -> list $1`},
23+
{`list/map (a -> a + 1)`, `list int -> list int`},
24+
// {`list/fold`, `($0 -> $1) -> list $0 -> list $1`},
25+
// {`list/repeat`, `int -> list a -> text`},
26+
27+
// text
28+
{`text/length`, `text -> int`},
29+
{`text/repeat`, `int -> text -> text`},
30+
// {`text/join`, `text -> list text -> text`},
31+
32+
{`list/fold 0 (a -> b -> a + text/length b)`, `list text -> int`},
33+
{`list/fold 0 (a -> b -> a + text/length b) ["hey", "beautiful"]`, `int`},
34+
}
35+
36+
for _, ex := range examples {
37+
env := NewEnvironment()
38+
scrap, err := env.Read([]byte(ex.source))
39+
40+
if err != nil {
41+
t.Error(err)
42+
continue
43+
}
44+
45+
res, err := env.Infer(scrap)
46+
47+
if err != nil {
48+
t.Error(err)
49+
} else {
50+
if res != ex.result {
51+
t.Errorf("Invalid type for '%s'\n expected: %s\n got: %s", ex.source, ex.result, res)
52+
}
53+
}
54+
}
55+
}

eval/eval_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ var expressions = []struct {
6666
{`3 |> a -> b -> a`, `b -> a`},
6767

6868
{`#true #false`, `<type>`}, // TODO: should be `#true #false`
69+
70+
{`list/repeat 2`, `list/repeat 2`},
71+
{`text/repeat 3`, `text/repeat 3`},
72+
{`text/join " "`, `text/join " "`},
73+
74+
{`"hi " ++ text/repeat 3 "a" ++ "ron"`, `"hi aaaron"`},
75+
{`"yo" |> list/repeat 2 |> text/join " "`, `"yo yo"`},
6976
}
7077

7178
var failures = []struct {
@@ -145,6 +152,7 @@ var exp2str = []struct{ source, result string }{
145152
{`list/fold 0 (a -> b -> a + b) []`, `0`},
146153
{`list/fold 0 (a -> b -> a + b)`, `list/fold 0 a -> b -> a + b`},
147154
{`list/fold 0 (a -> b -> a + b) [1, 2]`, `3`},
155+
{`list/fold 0 (a -> b -> a + text/length b) ["hey", "beautiful"]`, `12`},
148156

149157
{`[ 4 + 2, 5 - 1, ]`, "[ 6, 4 ]"},
150158
{`[ 1, 4 ] |> | [1,3] -> "three" |[_,4] -> "four"`, `"four"`},

0 commit comments

Comments
 (0)