From 7fad190898d3a4784b5851da14f37ee258de40ca Mon Sep 17 00:00:00 2001 From: Dave Lucia Date: Thu, 5 Feb 2026 12:26:05 -0500 Subject: [PATCH 1/4] feat: lexer / parser in Elixir This reimplements a Lua 5.3 lexer and parser, entirely in Elixir. This is not useful yet outside of if we want better error messages for invalid AST. More follow ups on this to come --- lib/lua/ast/block.ex | 63 ++ lib/lua/ast/builder.ex | 570 ++++++++++ lib/lua/ast/expr.ex | 216 ++++ lib/lua/ast/meta.ex | 89 ++ lib/lua/ast/pretty_printer.ex | 467 +++++++++ lib/lua/ast/stmt.ex | 262 +++++ lib/lua/ast/walker.ex | 287 +++++ lib/lua/lexer.ex | 522 +++++++++ lib/lua/parser.ex | 1166 +++++++++++++++++++++ lib/lua/parser/error.ex | 344 ++++++ lib/lua/parser/pratt.ex | 130 +++ lib/lua/parser/recovery.ex | 208 ++++ test/lua/ast/builder_test.exs | 481 +++++++++ test/lua/ast/pretty_printer_test.exs | 425 ++++++++ test/lua/ast/walker_test.exs | 293 ++++++ test/lua/lexer_test.exs | 482 +++++++++ test/lua/parser/beautiful_errors_test.exs | 269 +++++ test/lua/parser/error_test.exs | 172 +++ test/lua/parser/expr_test.exs | 85 ++ test/lua/parser/precedence_test.exs | 429 ++++++++ test/lua/parser/stmt_test.exs | 591 +++++++++++ 21 files changed, 7551 insertions(+) create mode 100644 lib/lua/ast/block.ex create mode 100644 lib/lua/ast/builder.ex create mode 100644 lib/lua/ast/expr.ex create mode 100644 lib/lua/ast/meta.ex create mode 100644 lib/lua/ast/pretty_printer.ex create mode 100644 lib/lua/ast/stmt.ex create mode 100644 lib/lua/ast/walker.ex create mode 100644 lib/lua/lexer.ex create mode 100644 lib/lua/parser.ex create mode 100644 lib/lua/parser/error.ex create mode 100644 lib/lua/parser/pratt.ex create mode 100644 lib/lua/parser/recovery.ex create mode 100644 test/lua/ast/builder_test.exs create mode 100644 test/lua/ast/pretty_printer_test.exs create mode 100644 test/lua/ast/walker_test.exs create mode 100644 test/lua/lexer_test.exs create mode 100644 test/lua/parser/beautiful_errors_test.exs create mode 100644 test/lua/parser/error_test.exs create mode 100644 test/lua/parser/expr_test.exs create mode 100644 test/lua/parser/precedence_test.exs create mode 100644 test/lua/parser/stmt_test.exs diff --git a/lib/lua/ast/block.ex b/lib/lua/ast/block.ex new file mode 100644 index 0000000..1ab1db6 --- /dev/null +++ b/lib/lua/ast/block.ex @@ -0,0 +1,63 @@ +defmodule Lua.AST.Block do + @moduledoc """ + Represents a block of statements in Lua. + + A block is a sequence of statements that execute in order. + Blocks create a new scope for local variables. + """ + + alias Lua.AST.{Meta, Stmt} + + @type t :: %__MODULE__{ + stmts: [Stmt.t()], + meta: Meta.t() | nil + } + + defstruct stmts: [], meta: nil + + @doc """ + Creates a new Block. + + ## Examples + + iex> Lua.AST.Block.new([]) + %Lua.AST.Block{stmts: [], meta: nil} + + iex> Lua.AST.Block.new([], %Lua.AST.Meta{}) + %Lua.AST.Block{stmts: [], meta: %Lua.AST.Meta{start: nil, end: nil, metadata: %{}}} + """ + @spec new([Stmt.t()], Meta.t() | nil) :: t() + def new(stmts \\ [], meta \\ nil) do + %__MODULE__{stmts: stmts, meta: meta} + end +end + +defmodule Lua.AST.Chunk do + @moduledoc """ + Represents the top-level chunk (file or string) in Lua. + + A chunk is essentially a block that represents a complete unit of Lua code. + """ + + alias Lua.AST.{Meta, Block} + + @type t :: %__MODULE__{ + block: Block.t(), + meta: Meta.t() | nil + } + + defstruct [:block, :meta] + + @doc """ + Creates a new Chunk. 
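+
+ The `block` argument is usually produced by the parser, but it can also be
+ built by hand with `Lua.AST.Block.new/2`.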
+ + ## Examples + + iex> Lua.AST.Chunk.new(%Lua.AST.Block{stmts: []}) + %Lua.AST.Chunk{block: %Lua.AST.Block{stmts: [], meta: nil}, meta: nil} + """ + @spec new(Block.t(), Meta.t() | nil) :: t() + def new(block, meta \\ nil) do + %__MODULE__{block: block, meta: meta} + end +end diff --git a/lib/lua/ast/builder.ex b/lib/lua/ast/builder.ex new file mode 100644 index 0000000..c3b8efe --- /dev/null +++ b/lib/lua/ast/builder.ex @@ -0,0 +1,570 @@ +defmodule Lua.AST.Builder do + @moduledoc """ + Helpers for programmatically constructing AST nodes. + + Provides a convenient API for building AST without manually + creating all the struct fields. Useful for: + - Code generation + - AST transformations + - Testing + - Metaprogramming with quote/unquote + + ## Examples + + import Lua.AST.Builder + + # Build a simple expression: 2 + 2 + binop(:add, number(2), number(2)) + + # Build a local assignment: local x = 42 + local(["x"], [number(42)]) + + # Build a function: function add(a, b) return a + b end + func_decl("add", ["a", "b"], [ + return_stmt([binop(:add, var("a"), var("b"))]) + ]) + """ + + alias Lua.AST.{Chunk, Block, Meta, Expr, Stmt} + + # Chunk and Block + + @doc """ + Creates a Chunk node. + + ## Examples + + chunk([local(["x"], [number(42)])]) + """ + @spec chunk([Stmt.t()], Meta.t() | nil) :: Chunk.t() + def chunk(stmts, meta \\ nil) do + %Chunk{ + block: block(stmts, meta), + meta: meta + } + end + + @doc """ + Creates a Block node. + + ## Examples + + block([ + local(["x"], [number(10)]), + assign([var("x")], [number(20)]) + ]) + """ + @spec block([Stmt.t()], Meta.t() | nil) :: Block.t() + def block(stmts, meta \\ nil) do + %Block{ + stmts: stmts, + meta: meta + } + end + + # Literal expressions + + @doc "Creates a nil literal" + @spec nil_lit(Meta.t() | nil) :: Expr.Nil.t() + def nil_lit(meta \\ nil), do: %Expr.Nil{meta: meta} + + @doc "Creates a boolean literal" + @spec bool(boolean(), Meta.t() | nil) :: Expr.Bool.t() + def bool(value, meta \\ nil), do: %Expr.Bool{value: value, meta: meta} + + @doc "Creates a number literal" + @spec number(number(), Meta.t() | nil) :: Expr.Number.t() + def number(value, meta \\ nil), do: %Expr.Number{value: value, meta: meta} + + @doc "Creates a string literal" + @spec string(String.t(), Meta.t() | nil) :: Expr.String.t() + def string(value, meta \\ nil), do: %Expr.String{value: value, meta: meta} + + @doc "Creates a vararg expression (...)" + @spec vararg(Meta.t() | nil) :: Expr.Vararg.t() + def vararg(meta \\ nil), do: %Expr.Vararg{meta: meta} + + # Variable and access + + @doc "Creates a variable reference" + @spec var(String.t(), Meta.t() | nil) :: Expr.Var.t() + def var(name, meta \\ nil), do: %Expr.Var{name: name, meta: meta} + + @doc """ + Creates a property access (obj.prop) + + ## Examples + + property(var("io"), "write") # io.write + """ + @spec property(Expr.t(), String.t(), Meta.t() | nil) :: Expr.Property.t() + def property(table, field, meta \\ nil) do + %Expr.Property{ + table: table, + field: field, + meta: meta + } + end + + @doc """ + Creates an index access (obj[index]) + + ## Examples + + index(var("t"), number(1)) # t[1] + """ + @spec index(Expr.t(), Expr.t(), Meta.t() | nil) :: Expr.Index.t() + def index(table, key, meta \\ nil) do + %Expr.Index{ + table: table, + key: key, + meta: meta + } + end + + # Operators + + @doc """ + Creates a binary operation. 
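+
+ Both operands are expression nodes, so calls nest to build larger trees; for
+ example `binop(:mul, binop(:add, number(1), number(2)), number(3))` builds
+ the tree for `(1 + 2) * 3`.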
+ + ## Operators + + - `:add`, `:sub`, `:mul`, `:div`, `:floor_div`, `:mod`, `:pow` + - `:concat` + - `:eq`, `:ne`, `:lt`, `:gt`, `:le`, `:ge` + - `:and`, `:or` + + ## Examples + + binop(:add, number(2), number(3)) # 2 + 3 + binop(:lt, var("x"), number(10)) # x < 10 + """ + @spec binop(atom(), Expr.t(), Expr.t(), Meta.t() | nil) :: Expr.BinOp.t() + def binop(op, left, right, meta \\ nil) do + %Expr.BinOp{ + op: op, + left: left, + right: right, + meta: meta + } + end + + @doc """ + Creates a unary operation. + + ## Operators + + - `:not` - logical not + - `:neg` - negation (-) + - `:len` - length operator (#) + + ## Examples + + unop(:neg, var("x")) # -x + unop(:not, var("flag")) # not flag + unop(:len, var("list")) # #list + """ + @spec unop(atom(), Expr.t(), Meta.t() | nil) :: Expr.UnOp.t() + def unop(op, operand, meta \\ nil) do + %Expr.UnOp{ + op: op, + operand: operand, + meta: meta + } + end + + # Table constructor + + @doc """ + Creates a table constructor. + + ## Field types + + - `{:list, expr}` - array-style field (value only) + - `{:record, key_expr, value_expr}` - key-value field + + ## Examples + + # Empty table: {} + table([]) + + # Array: {1, 2, 3} + table([ + {:list, number(1)}, + {:list, number(2)}, + {:list, number(3)} + ]) + + # Record: {x = 10, y = 20} + table([ + {:record, string("x"), number(10)}, + {:record, string("y"), number(20)} + ]) + """ + @spec table([{:list, Expr.t()} | {:record, Expr.t(), Expr.t()}], Meta.t() | nil) :: Expr.Table.t() + def table(fields, meta \\ nil) do + %Expr.Table{ + fields: fields, + meta: meta + } + end + + # Function call + + @doc """ + Creates a function call. + + ## Examples + + call(var("print"), [string("hello")]) # print("hello") + call(property(var("io"), "write"), [string("test")]) # io.write("test") + """ + @spec call(Expr.t(), [Expr.t()], Meta.t() | nil) :: Expr.Call.t() + def call(func, args, meta \\ nil) do + %Expr.Call{ + func: func, + args: args, + meta: meta + } + end + + @doc """ + Creates a method call (obj:method(args)) + + ## Examples + + method_call(var("file"), "read", [string("*a")]) # file:read("*a") + """ + @spec method_call(Expr.t(), String.t(), [Expr.t()], Meta.t() | nil) :: Expr.MethodCall.t() + def method_call(object, method, args, meta \\ nil) do + %Expr.MethodCall{ + object: object, + method: method, + args: args, + meta: meta + } + end + + # Function expression + + @doc """ + Creates a function expression. + + ## Examples + + # function(x, y) return x + y end + function_expr(["x", "y"], [ + return_stmt([binop(:add, var("x"), var("y"))]) + ]) + + # function(...) return ... end + function_expr([], [return_stmt([vararg()])], vararg: true) + """ + @spec function_expr([String.t()], [Stmt.t()], keyword()) :: Expr.Function.t() + def function_expr(params, body_stmts, opts \\ []) do + params_with_vararg = + if Keyword.get(opts, :vararg, false) do + params ++ [:vararg] + else + params + end + + %Expr.Function{ + params: params_with_vararg, + body: block(body_stmts), + meta: Keyword.get(opts, :meta) + } + end + + # Statements + + @doc """ + Creates an assignment statement. + + ## Examples + + # x = 10 + assign([var("x")], [number(10)]) + + # x, y = 1, 2 + assign([var("x"), var("y")], [number(1), number(2)]) + """ + @spec assign([Expr.t()], [Expr.t()], Meta.t() | nil) :: Stmt.Assign.t() + def assign(targets, values, meta \\ nil) do + %Stmt.Assign{ + targets: targets, + values: values, + meta: meta + } + end + + @doc """ + Creates a local variable declaration. 
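+
+ `values` may be shorter than `names`; as in Lua itself, any name without a
+ corresponding value starts out as `nil`.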
+ + ## Examples + + # local x + local(["x"], []) + + # local x = 10 + local(["x"], [number(10)]) + + # local x, y = 1, 2 + local(["x", "y"], [number(1), number(2)]) + """ + @spec local([String.t()], [Expr.t()], Meta.t() | nil) :: Stmt.Local.t() + def local(names, values \\ [], meta \\ nil) do + %Stmt.Local{ + names: names, + values: values, + meta: meta + } + end + + @doc """ + Creates a local function declaration. + + ## Examples + + # local function add(a, b) return a + b end + local_func("add", ["a", "b"], [ + return_stmt([binop(:add, var("a"), var("b"))]) + ]) + """ + @spec local_func(String.t(), [String.t()], [Stmt.t()], keyword()) :: Stmt.LocalFunc.t() + def local_func(name, params, body_stmts, opts \\ []) do + params_with_vararg = + if Keyword.get(opts, :vararg, false) do + params ++ [:vararg] + else + params + end + + %Stmt.LocalFunc{ + name: name, + params: params_with_vararg, + body: block(body_stmts), + meta: Keyword.get(opts, :meta) + } + end + + @doc """ + Creates a function declaration. + + ## Examples + + # function add(a, b) return a + b end + func_decl("add", ["a", "b"], [ + return_stmt([binop(:add, var("a"), var("b"))]) + ]) + + # function math.add(a, b) return a + b end + func_decl(["math", "add"], ["a", "b"], [...]) + """ + @spec func_decl(String.t() | [String.t()], [String.t()], [Stmt.t()], keyword()) :: Stmt.FuncDecl.t() + def func_decl(name, params, body_stmts, opts \\ []) do + name_parts = if is_binary(name), do: [name], else: name + + params_with_vararg = + if Keyword.get(opts, :vararg, false) do + params ++ [:vararg] + else + params + end + + is_method = Keyword.get(opts, :is_method, false) + + %Stmt.FuncDecl{ + name: name_parts, + params: params_with_vararg, + body: block(body_stmts), + is_method: is_method, + meta: Keyword.get(opts, :meta) + } + end + + @doc """ + Creates a function call statement. + + ## Examples + + call_stmt(call(var("print"), [string("hello")])) + """ + @spec call_stmt(Expr.Call.t() | Expr.MethodCall.t(), Meta.t() | nil) :: Stmt.CallStmt.t() + def call_stmt(call_expr, meta \\ nil) do + %Stmt.CallStmt{ + call: call_expr, + meta: meta + } + end + + @doc """ + Creates an if statement. + + ## Examples + + # if x > 0 then print(x) end + if_stmt( + binop(:gt, var("x"), number(0)), + [call_stmt(call(var("print"), [var("x")]))] + ) + + # if x > 0 then ... elseif x < 0 then ... else ... end + if_stmt( + binop(:gt, var("x"), number(0)), + [call_stmt(...)], + elseif: [{binop(:lt, var("x"), number(0)), [call_stmt(...)]}], + else: [call_stmt(...)] + ) + """ + @spec if_stmt(Expr.t(), [Stmt.t()], keyword()) :: Stmt.If.t() + def if_stmt(condition, then_stmts, opts \\ []) do + %Stmt.If{ + condition: condition, + then_block: block(then_stmts), + elseifs: Keyword.get(opts, :elseif, []) |> Enum.map(fn {c, s} -> {c, block(s)} end), + else_block: if(else_stmts = Keyword.get(opts, :else), do: block(else_stmts)), + meta: Keyword.get(opts, :meta) + } + end + + @doc """ + Creates a while loop. + + ## Examples + + # while x > 0 do x = x - 1 end + while_stmt( + binop(:gt, var("x"), number(0)), + [assign([var("x")], [binop(:sub, var("x"), number(1))])] + ) + """ + @spec while_stmt(Expr.t(), [Stmt.t()], Meta.t() | nil) :: Stmt.While.t() + def while_stmt(condition, body_stmts, meta \\ nil) do + %Stmt.While{ + condition: condition, + body: block(body_stmts), + meta: meta + } + end + + @doc """ + Creates a repeat-until loop. 
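+
+ Note that the argument order mirrors Lua's syntax: the body statements come
+ first, followed by the `until` condition.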
+ + ## Examples + + # repeat x = x - 1 until x <= 0 + repeat_stmt( + [assign([var("x")], [binop(:sub, var("x"), number(1))])], + binop(:le, var("x"), number(0)) + ) + """ + @spec repeat_stmt([Stmt.t()], Expr.t(), Meta.t() | nil) :: Stmt.Repeat.t() + def repeat_stmt(body_stmts, condition, meta \\ nil) do + %Stmt.Repeat{ + body: block(body_stmts), + condition: condition, + meta: meta + } + end + + @doc """ + Creates a numeric for loop. + + ## Examples + + # for i = 1, 10 do print(i) end + for_num("i", number(1), number(10), [ + call_stmt(call(var("print"), [var("i")])) + ]) + + # for i = 1, 10, 2 do print(i) end + for_num("i", number(1), number(10), [...], step: number(2)) + """ + @spec for_num(String.t(), Expr.t(), Expr.t(), [Stmt.t()], keyword()) :: Stmt.ForNum.t() + def for_num(var_name, start, limit, body_stmts, opts \\ []) do + %Stmt.ForNum{ + var: var_name, + start: start, + limit: limit, + step: Keyword.get(opts, :step), + body: block(body_stmts), + meta: Keyword.get(opts, :meta) + } + end + + @doc """ + Creates a generic for loop (for-in). + + ## Examples + + # for k, v in pairs(t) do print(k, v) end + for_in( + ["k", "v"], + [call(var("pairs"), [var("t")])], + [call_stmt(call(var("print"), [var("k"), var("v")]))] + ) + """ + @spec for_in([String.t()], [Expr.t()], [Stmt.t()], Meta.t() | nil) :: Stmt.ForIn.t() + def for_in(vars, iterators, body_stmts, meta \\ nil) do + %Stmt.ForIn{ + vars: vars, + iterators: iterators, + body: block(body_stmts), + meta: meta + } + end + + @doc """ + Creates a do block. + + ## Examples + + # do local x = 10; print(x) end + do_block([ + local(["x"], [number(10)]), + call_stmt(call(var("print"), [var("x")])) + ]) + """ + @spec do_block([Stmt.t()], Meta.t() | nil) :: Stmt.Do.t() + def do_block(body_stmts, meta \\ nil) do + %Stmt.Do{ + body: block(body_stmts), + meta: meta + } + end + + @doc """ + Creates a return statement. + + ## Examples + + # return + return_stmt([]) + + # return 42 + return_stmt([number(42)]) + + # return x, y + return_stmt([var("x"), var("y")]) + """ + @spec return_stmt([Expr.t()], Meta.t() | nil) :: Stmt.Return.t() + def return_stmt(values, meta \\ nil) do + %Stmt.Return{ + values: values, + meta: meta + } + end + + @doc "Creates a break statement" + @spec break_stmt(Meta.t() | nil) :: Stmt.Break.t() + def break_stmt(meta \\ nil), do: %Stmt.Break{meta: meta} + + @doc "Creates a goto statement" + @spec goto_stmt(String.t(), Meta.t() | nil) :: Stmt.Goto.t() + def goto_stmt(label, meta \\ nil), do: %Stmt.Goto{label: label, meta: meta} + + @doc "Creates a label" + @spec label(String.t(), Meta.t() | nil) :: Stmt.Label.t() + def label(name, meta \\ nil), do: %Stmt.Label{name: name, meta: meta} +end diff --git a/lib/lua/ast/expr.ex b/lib/lua/ast/expr.ex new file mode 100644 index 0000000..90b238d --- /dev/null +++ b/lib/lua/ast/expr.ex @@ -0,0 +1,216 @@ +defmodule Lua.AST.Expr do + @moduledoc """ + Expression AST nodes for Lua. + + All expression nodes include a `meta` field for position tracking. 
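+
+ For example, the expression `1 + x` is represented as (a sketch; `meta`
+ fields omitted):
+
+     %Lua.AST.Expr.BinOp{
+       op: :add,
+       left: %Lua.AST.Expr.Number{value: 1},
+       right: %Lua.AST.Expr.Var{name: "x"}
+     }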
+ """ + + alias Lua.AST.Meta + + @type t :: + Nil.t() + | Bool.t() + | Number.t() + | String.t() + | Var.t() + | BinOp.t() + | UnOp.t() + | Table.t() + | Call.t() + | MethodCall.t() + | Index.t() + | Property.t() + | Function.t() + | Vararg.t() + + defmodule Nil do + @moduledoc "Represents the `nil` literal" + defstruct [:meta] + @type t :: %__MODULE__{meta: Meta.t() | nil} + end + + defmodule Bool do + @moduledoc "Represents boolean literals (`true` or `false`)" + defstruct [:value, :meta] + @type t :: %__MODULE__{value: boolean(), meta: Meta.t() | nil} + end + + defmodule Number do + @moduledoc "Represents numeric literals (integers and floats)" + defstruct [:value, :meta] + @type t :: %__MODULE__{value: number(), meta: Meta.t() | nil} + end + + defmodule String do + @moduledoc "Represents string literals" + defstruct [:value, :meta] + @type t :: %__MODULE__{value: String.t(), meta: Meta.t() | nil} + end + + defmodule Var do + @moduledoc "Represents a variable reference" + defstruct [:name, :meta] + @type t :: %__MODULE__{name: String.t(), meta: Meta.t() | nil} + end + + defmodule BinOp do + @moduledoc """ + Represents a binary operation. + + Operators: + - Arithmetic: `:add`, `:sub`, `:mul`, `:div`, `:floordiv`, `:mod`, `:pow` + - Comparison: `:eq`, `:ne`, `:lt`, `:le`, `:gt`, `:ge` + - Logical: `:and`, `:or` + - String: `:concat` + """ + defstruct [:op, :left, :right, :meta] + + @type op :: + :add + | :sub + | :mul + | :div + | :floordiv + | :mod + | :pow + | :eq + | :ne + | :lt + | :le + | :gt + | :ge + | :and + | :or + | :concat + + @type t :: %__MODULE__{ + op: op(), + left: Lua.AST.Expr.t(), + right: Lua.AST.Expr.t(), + meta: Meta.t() | nil + } + end + + defmodule UnOp do + @moduledoc """ + Represents a unary operation. + + Operators: + - `:not` - logical not + - `:neg` - arithmetic negation (-) + - `:len` - length operator (#) + """ + defstruct [:op, :operand, :meta] + + @type op :: :not | :neg | :len + + @type t :: %__MODULE__{ + op: op(), + operand: Lua.AST.Expr.t(), + meta: Meta.t() | nil + } + end + + defmodule Table do + @moduledoc """ + Represents a table constructor: `{...}` + + Fields can be: + - List entries: `{1, 2, 3}` -> `[{:list, expr}, ...]` + - Key-value pairs: `{a = 1}` -> `[{:pair, key_expr, val_expr}, ...]` + - Computed keys: `{["key"] = value}` -> `[{:pair, key_expr, val_expr}, ...]` + """ + defstruct [:fields, :meta] + + @type field :: + {:list, Lua.AST.Expr.t()} + | {:pair, Lua.AST.Expr.t(), Lua.AST.Expr.t()} + + @type t :: %__MODULE__{ + fields: [field()], + meta: Meta.t() | nil + } + end + + defmodule Call do + @moduledoc """ + Represents a function call: `func(args)` + """ + defstruct [:func, :args, :meta] + + @type t :: %__MODULE__{ + func: Lua.AST.Expr.t(), + args: [Lua.AST.Expr.t()], + meta: Meta.t() | nil + } + end + + defmodule MethodCall do + @moduledoc """ + Represents a method call: `obj:method(args)` + + This is syntactic sugar for `obj.method(obj, args)` in Lua. + """ + defstruct [:object, :method, :args, :meta] + + @type t :: %__MODULE__{ + object: Lua.AST.Expr.t(), + method: String.t(), + args: [Lua.AST.Expr.t()], + meta: Meta.t() | nil + } + end + + defmodule Index do + @moduledoc """ + Represents indexing with brackets: `table[key]` + """ + defstruct [:table, :key, :meta] + + @type t :: %__MODULE__{ + table: Lua.AST.Expr.t(), + key: Lua.AST.Expr.t(), + meta: Meta.t() | nil + } + end + + defmodule Property do + @moduledoc """ + Represents property access: `table.field` + + This is syntactic sugar for `table["field"]` in Lua. 
+ """ + defstruct [:table, :field, :meta] + + @type t :: %__MODULE__{ + table: Lua.AST.Expr.t(), + field: String.t(), + meta: Meta.t() | nil + } + end + + defmodule Function do + @moduledoc """ + Represents a function expression: `function(params) body end` + + Params can include: + - Named parameters: `["a", "b", "c"]` + - Vararg: `{:vararg}` as the last element + """ + defstruct [:params, :body, :meta] + + @type param :: String.t() | :vararg + + @type t :: %__MODULE__{ + params: [param()], + body: Lua.AST.Block.t(), + meta: Meta.t() | nil + } + end + + defmodule Vararg do + @moduledoc "Represents the vararg expression: `...`" + defstruct [:meta] + @type t :: %__MODULE__{meta: Meta.t() | nil} + end +end diff --git a/lib/lua/ast/meta.ex b/lib/lua/ast/meta.ex new file mode 100644 index 0000000..1685c6e --- /dev/null +++ b/lib/lua/ast/meta.ex @@ -0,0 +1,89 @@ +defmodule Lua.AST.Meta do + @moduledoc """ + Position tracking metadata for AST nodes. + + Every AST node includes a `meta` field containing position information + for error reporting, source maps, and debugging. + """ + + @type position :: %{ + line: pos_integer(), + column: pos_integer(), + byte_offset: non_neg_integer() + } + + @type t :: %__MODULE__{ + start: position() | nil, + end: position() | nil, + metadata: map() + } + + defstruct start: nil, end: nil, metadata: %{} + + @doc """ + Creates a new Meta struct with start and end positions. + + ## Examples + + iex> Lua.AST.Meta.new( + ...> %{line: 1, column: 1, byte_offset: 0}, + ...> %{line: 1, column: 5, byte_offset: 4} + ...> ) + %Lua.AST.Meta{ + start: %{line: 1, column: 1, byte_offset: 0}, + end: %{line: 1, column: 5, byte_offset: 4}, + metadata: %{} + } + """ + @spec new(position() | nil, position() | nil, map()) :: t() + def new(start \\ nil, end_pos \\ nil, metadata \\ %{}) do + %__MODULE__{start: start, end: end_pos, metadata: metadata} + end + + @doc """ + Merges two Meta structs, taking the earliest start and latest end. + + Useful when combining multiple nodes into a single parent node. + + ## Examples + + iex> meta1 = Lua.AST.Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 5, byte_offset: 4}) + iex> meta2 = Lua.AST.Meta.new(%{line: 1, column: 7, byte_offset: 6}, %{line: 1, column: 10, byte_offset: 9}) + iex> Lua.AST.Meta.merge(meta1, meta2) + %Lua.AST.Meta{ + start: %{line: 1, column: 1, byte_offset: 0}, + end: %{line: 1, column: 10, byte_offset: 9}, + metadata: %{} + } + """ + @spec merge(t(), t()) :: t() + def merge(%__MODULE__{start: start1, end: end1}, %__MODULE__{start: start2, end: end2}) do + new_start = earliest_position(start1, start2) + new_end = latest_position(end1, end2) + new(new_start, new_end) + end + + @doc """ + Adds metadata to an existing Meta struct. 
+ """ + @spec add_metadata(t(), atom(), term()) :: t() + def add_metadata(%__MODULE__{metadata: metadata} = meta, key, value) do + %{meta | metadata: Map.put(metadata, key, value)} + end + + # Private helpers + + defp earliest_position(nil, pos), do: pos + defp earliest_position(pos, nil), do: pos + + defp earliest_position(pos1, pos2) do + if pos1.byte_offset <= pos2.byte_offset, do: pos1, else: pos2 + end + + defp latest_position(nil, pos), do: pos + defp latest_position(pos, nil), do: pos + + defp latest_position(pos1, pos2) do + if pos1.byte_offset >= pos2.byte_offset, do: pos1, else: pos2 + end +end diff --git a/lib/lua/ast/pretty_printer.ex b/lib/lua/ast/pretty_printer.ex new file mode 100644 index 0000000..c326632 --- /dev/null +++ b/lib/lua/ast/pretty_printer.ex @@ -0,0 +1,467 @@ +defmodule Lua.AST.PrettyPrinter do + @moduledoc """ + Converts AST back to Lua source code. + + Produces readable, properly indented Lua code from AST structures. + Useful for: + - Round-trip testing (parse → print → parse) + - Debugging AST transformations + - Code generation + + ## Examples + + ast = Parser.parse("local x = 2 + 2") + code = PrettyPrinter.print(ast) + # => "local x = 2 + 2\\n" + + # With custom indentation + PrettyPrinter.print(ast, indent: 4) + """ + + alias Lua.AST.{Chunk, Block, Expr, Stmt} + + @type ast_node :: + Chunk.t() + | Block.t() + | Expr.t() + | Stmt.t() + + @type opts :: [ + indent: pos_integer() + ] + + @doc """ + Converts an AST node to Lua source code. + + ## Options + + - `:indent` - Number of spaces per indentation level (default: 2) + + ## Examples + + PrettyPrinter.print(ast) + PrettyPrinter.print(ast, indent: 4) + """ + @spec print(ast_node, opts) :: String.t() + def print(node, opts \\ []) do + indent_size = Keyword.get(opts, :indent, 2) + do_print(node, 0, indent_size) + end + + # Chunk + defp do_print(%Chunk{block: block}, level, indent_size) do + do_print(block, level, indent_size) + end + + # Block + defp do_print(%Block{stmts: stmts}, level, indent_size) do + stmts + |> Enum.map(&do_print(&1, level, indent_size)) + |> Enum.join("\n") + |> Kernel.<>("\n") + end + + # Expressions + + defp do_print(%Expr.Nil{}, _level, _indent_size), do: "nil" + defp do_print(%Expr.Bool{value: true}, _level, _indent_size), do: "true" + defp do_print(%Expr.Bool{value: false}, _level, _indent_size), do: "false" + + defp do_print(%Expr.Number{value: n}, _level, _indent_size) do + # Format numbers nicely + if is_float(n) and Float.floor(n) == n do + # Integer-valued float + "#{trunc(n)}.0" + else + "#{n}" + end + end + + defp do_print(%Expr.String{value: s}, _level, _indent_size) do + # Escape special characters + escaped = + s + |> String.replace("\\", "\\\\") + |> String.replace("\"", "\\\"") + |> String.replace("\n", "\\n") + |> String.replace("\t", "\\t") + + "\"#{escaped}\"" + end + + defp do_print(%Expr.Vararg{}, _level, _indent_size), do: "..." 
+ + defp do_print(%Expr.Var{name: name}, _level, _indent_size), do: name + + defp do_print(%Expr.BinOp{op: op, left: left, right: right}, level, indent_size) do + left_str = print_expr_with_parens(left, op, :left, level, indent_size) + right_str = print_expr_with_parens(right, op, :right, level, indent_size) + op_str = format_binop(op) + + "#{left_str} #{op_str} #{right_str}" + end + + defp do_print(%Expr.UnOp{op: op, operand: operand}, level, indent_size) do + operand_str = print_expr_with_parens(operand, op, :operand, level, indent_size) + op_str = format_unop(op) + + "#{op_str}#{operand_str}" + end + + defp do_print(%Expr.Table{fields: fields}, level, indent_size) do + if fields == [] do + "{}" + else + field_strs = + Enum.map(fields, fn + {:list, value} -> + do_print(value, level + 1, indent_size) + + {:record, key, value} -> + key_str = format_table_key(key, level + 1, indent_size) + value_str = do_print(value, level + 1, indent_size) + "#{key_str} = #{value_str}" + end) + + "{#{Enum.join(field_strs, ", ")}}" + end + end + + defp do_print(%Expr.Call{func: func, args: args}, level, indent_size) do + func_str = do_print(func, level, indent_size) + args_str = Enum.map(args, &do_print(&1, level, indent_size)) |> Enum.join(", ") + + "#{func_str}(#{args_str})" + end + + defp do_print(%Expr.MethodCall{object: obj, method: method, args: args}, level, indent_size) do + obj_str = do_print(obj, level, indent_size) + args_str = Enum.map(args, &do_print(&1, level, indent_size)) |> Enum.join(", ") + + "#{obj_str}:#{method}(#{args_str})" + end + + defp do_print(%Expr.Index{table: table, key: key}, level, indent_size) do + table_str = do_print(table, level, indent_size) + key_str = do_print(key, level, indent_size) + + "#{table_str}[#{key_str}]" + end + + defp do_print(%Expr.Property{table: table, field: field}, level, indent_size) do + table_str = do_print(table, level, indent_size) + "#{table_str}.#{field}" + end + + defp do_print(%Expr.Function{params: params, body: body}, level, indent_size) do + params_str = + params + |> Enum.map(fn + :vararg -> "..." + name -> name + end) + |> Enum.join(", ") + + body_str = print_block_body(body, level + 1, indent_size) + + "function(#{params_str})\n#{body_str}#{indent(level, indent_size)}end" + end + + # Statements + + defp do_print(%Stmt.Assign{targets: targets, values: values}, level, indent_size) do + targets_str = Enum.map(targets, &do_print(&1, level, indent_size)) |> Enum.join(", ") + values_str = Enum.map(values, &do_print(&1, level, indent_size)) |> Enum.join(", ") + + "#{indent(level, indent_size)}#{targets_str} = #{values_str}" + end + + defp do_print(%Stmt.Local{names: names, values: values}, level, indent_size) do + names_str = Enum.join(names, ", ") + + if values && values != [] do + values_str = Enum.map(values, &do_print(&1, level, indent_size)) |> Enum.join(", ") + "#{indent(level, indent_size)}local #{names_str} = #{values_str}" + else + "#{indent(level, indent_size)}local #{names_str}" + end + end + + defp do_print(%Stmt.LocalFunc{name: name, params: params, body: body}, level, indent_size) do + params_str = + params + |> Enum.map(fn + :vararg -> "..." + name -> name + end) + |> Enum.join(", ") + + body_str = print_block_body(body, level + 1, indent_size) + + "#{indent(level, indent_size)}local function #{name}(#{params_str})\n#{body_str}#{indent(level, indent_size)}end" + end + + defp do_print(%Stmt.FuncDecl{name: name, params: params, body: body}, level, indent_size) do + params_str = + params + |> Enum.map(fn + :vararg -> "..." 
+ param_name -> param_name + end) + |> Enum.join(", ") + + body_str = print_block_body(body, level + 1, indent_size) + + "#{indent(level, indent_size)}function #{format_func_name(name)}(#{params_str})\n#{body_str}#{indent(level, indent_size)}end" + end + + defp do_print(%Stmt.CallStmt{call: call}, level, indent_size) do + "#{indent(level, indent_size)}#{do_print(call, level, indent_size)}" + end + + defp do_print(%Stmt.If{condition: cond, then_block: then_block, elseifs: elseifs, else_block: else_block}, level, indent_size) do + cond_str = do_print(cond, level, indent_size) + then_str = print_block_body(then_block, level + 1, indent_size) + + elseif_strs = + Enum.map(elseifs, fn {c, b} -> + c_str = do_print(c, level, indent_size) + b_str = print_block_body(b, level + 1, indent_size) + "#{indent(level, indent_size)}elseif #{c_str} then\n#{b_str}" + end) + + else_str = + if else_block do + b_str = print_block_body(else_block, level + 1, indent_size) + "#{indent(level, indent_size)}else\n#{b_str}" + else + nil + end + + parts = ["#{indent(level, indent_size)}if #{cond_str} then\n#{then_str}"] ++ elseif_strs + + parts = + if else_str do + parts ++ [else_str] + else + parts + end + + Enum.join(parts, "") <> "#{indent(level, indent_size)}end" + end + + defp do_print(%Stmt.While{condition: cond, body: body}, level, indent_size) do + cond_str = do_print(cond, level, indent_size) + body_str = print_block_body(body, level + 1, indent_size) + + "#{indent(level, indent_size)}while #{cond_str} do\n#{body_str}#{indent(level, indent_size)}end" + end + + defp do_print(%Stmt.Repeat{body: body, condition: cond}, level, indent_size) do + body_str = print_block_body(body, level + 1, indent_size) + cond_str = do_print(cond, level, indent_size) + + "#{indent(level, indent_size)}repeat\n#{body_str}#{indent(level, indent_size)}until #{cond_str}" + end + + defp do_print(%Stmt.ForNum{var: var, start: start, limit: limit, step: step, body: body}, level, indent_size) do + start_str = do_print(start, level, indent_size) + limit_str = do_print(limit, level, indent_size) + body_str = print_block_body(body, level + 1, indent_size) + + step_str = + if step do + ", #{do_print(step, level, indent_size)}" + else + "" + end + + "#{indent(level, indent_size)}for #{var} = #{start_str}, #{limit_str}#{step_str} do\n#{body_str}#{indent(level, indent_size)}end" + end + + defp do_print(%Stmt.ForIn{vars: vars, iterators: iterators, body: body}, level, indent_size) do + vars_str = Enum.join(vars, ", ") + iterators_str = Enum.map(iterators, &do_print(&1, level, indent_size)) |> Enum.join(", ") + body_str = print_block_body(body, level + 1, indent_size) + + "#{indent(level, indent_size)}for #{vars_str} in #{iterators_str} do\n#{body_str}#{indent(level, indent_size)}end" + end + + defp do_print(%Stmt.Do{body: body}, level, indent_size) do + body_str = print_block_body(body, level + 1, indent_size) + + "#{indent(level, indent_size)}do\n#{body_str}#{indent(level, indent_size)}end" + end + + defp do_print(%Stmt.Return{values: values}, level, indent_size) do + if values == [] do + "#{indent(level, indent_size)}return" + else + values_str = Enum.map(values, &do_print(&1, level, indent_size)) |> Enum.join(", ") + "#{indent(level, indent_size)}return #{values_str}" + end + end + + defp do_print(%Stmt.Break{}, level, indent_size) do + "#{indent(level, indent_size)}break" + end + + defp do_print(%Stmt.Goto{label: label}, level, indent_size) do + "#{indent(level, indent_size)}goto #{label}" + end + + defp do_print(%Stmt.Label{name: name}, level, 
indent_size) do + "#{indent(level, indent_size)}::#{name}::" + end + + # Helpers + + defp indent(level, indent_size) do + String.duplicate(" ", level * indent_size) + end + + defp print_block_body(%Block{stmts: stmts}, level, indent_size) do + stmts + |> Enum.map(&do_print(&1, level, indent_size)) + |> Enum.join("\n") + |> Kernel.<>("\n") + end + + # Add parentheses when needed for operator precedence + defp print_expr_with_parens(expr, parent_op, position, level, indent_size) do + expr_str = do_print(expr, level, indent_size) + + if needs_parens?(expr, parent_op, position) do + "(#{expr_str})" + else + expr_str + end + end + + # Determine if parentheses are needed based on precedence + defp needs_parens?(expr, parent_op, position) do + case expr do + %Expr.BinOp{op: child_op} -> + parent_prec = binop_precedence(parent_op) + child_prec = binop_precedence(child_op) + + cond do + # Lower precedence always needs parens + child_prec < parent_prec -> true + # Same precedence needs parens if associativity doesn't match + child_prec == parent_prec -> needs_parens_same_prec?(parent_op, position) + # Higher precedence never needs parens + true -> false + end + + %Expr.UnOp{} -> + # Unary ops have high precedence, rarely need parens + case parent_op do + :pow -> true # -2^3 should be -(2^3) + _ -> false + end + + _ -> + false + end + end + + defp needs_parens_same_prec?(op, position) do + # Right-associative operators need parens on the left + # Left-associative operators need parens on the right + case {is_right_assoc?(op), position} do + {true, :left} -> true + {false, :right} -> true + _ -> false + end + end + + defp is_right_assoc?(op) do + op in [:concat, :pow] + end + + defp binop_precedence(op) do + case op do + :or -> 1 + :and -> 2 + :lt -> 3 + :gt -> 3 + :le -> 3 + :ge -> 3 + :ne -> 3 + :eq -> 3 + :concat -> 4 + :add -> 5 + :sub -> 5 + :mul -> 6 + :div -> 6 + :floor_div -> 6 + :mod -> 6 + :pow -> 8 + _ -> 0 + end + end + + defp format_binop(op) do + case op do + :add -> "+" + :sub -> "-" + :mul -> "*" + :div -> "/" + :floor_div -> "//" + :mod -> "%" + :pow -> "^" + :concat -> ".." + :eq -> "==" + :ne -> "~=" + :lt -> "<" + :gt -> ">" + :le -> "<=" + :ge -> ">=" + :and -> "and" + :or -> "or" + _ -> "" + end + end + + defp format_unop(op) do + case op do + :not -> "not " + :neg -> "-" + :len -> "#" + _ -> "" + end + end + + defp format_table_key(key, level, indent_size) do + case key do + %Expr.String{value: s} -> + # If it's a valid identifier, use shorthand + if valid_identifier?(s) do + s + else + "[#{do_print(key, level, indent_size)}]" + end + + _ -> + "[#{do_print(key, level, indent_size)}]" + end + end + + defp valid_identifier?(s) do + # Check if string is a valid Lua identifier + Regex.match?(~r/^[a-zA-Z_][a-zA-Z0-9_]*$/, s) and not lua_keyword?(s) + end + + defp lua_keyword?(s) do + s in ~w(and break do else elseif end false for function goto if in local nil not or repeat return then true until while) + end + + defp format_func_name(parts) when is_list(parts) do + Enum.join(parts, ".") + end + + defp format_func_name(name) when is_binary(name) do + name + end +end diff --git a/lib/lua/ast/stmt.ex b/lib/lua/ast/stmt.ex new file mode 100644 index 0000000..331c9f3 --- /dev/null +++ b/lib/lua/ast/stmt.ex @@ -0,0 +1,262 @@ +defmodule Lua.AST.Stmt do + @moduledoc """ + Statement AST nodes for Lua. + + All statement nodes include a `meta` field for position tracking. 
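+
+ For example, the statement `local x = 1` is represented as (a sketch; `meta`
+ fields omitted):
+
+     %Lua.AST.Stmt.Local{
+       names: ["x"],
+       values: [%Lua.AST.Expr.Number{value: 1}]
+     }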
+ """ + + alias Lua.AST.{Meta, Expr, Block} + + @type t :: + Assign.t() + | Local.t() + | LocalFunc.t() + | FuncDecl.t() + | CallStmt.t() + | If.t() + | While.t() + | Repeat.t() + | ForNum.t() + | ForIn.t() + | Do.t() + | Return.t() + | Break.t() + | Goto.t() + | Label.t() + + defmodule Assign do + @moduledoc """ + Represents an assignment statement: `targets = values` + + Both targets and values can be lists for multiple assignment: + `a, b = 1, 2` + """ + defstruct [:targets, :values, :meta] + + @type t :: %__MODULE__{ + targets: [Expr.t()], + values: [Expr.t()], + meta: Meta.t() | nil + } + end + + defmodule Local do + @moduledoc """ + Represents a local variable declaration: `local names = values` + + Values can be empty for declaration without initialization. + """ + defstruct [:names, :values, :meta] + + @type t :: %__MODULE__{ + names: [String.t()], + values: [Expr.t()], + meta: Meta.t() | nil + } + end + + defmodule LocalFunc do + @moduledoc """ + Represents a local function declaration: `local function name(params) body end` + """ + defstruct [:name, :params, :body, :meta] + + @type t :: %__MODULE__{ + name: String.t(), + params: [Expr.Function.param()], + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule FuncDecl do + @moduledoc """ + Represents a function declaration: `function name(params) body end` + + The name can be a path for nested names: + - `function foo() end` -> name: `["foo"]`, is_method: false + - `function a.b.c() end` -> name: `["a", "b", "c"]`, is_method: false + - `function obj:method() end` -> name: `["obj", "method"]`, is_method: true + + When `is_method` is true, an implicit `self` parameter is added. + """ + defstruct [:name, :params, :body, :is_method, :meta] + + @type t :: %__MODULE__{ + name: [String.t()], + params: [Expr.Function.param()], + body: Block.t(), + is_method: boolean(), + meta: Meta.t() | nil + } + end + + defmodule CallStmt do + @moduledoc """ + Represents a function call as a statement. + + In Lua, function calls can be expressions or statements. + """ + defstruct [:call, :meta] + + @type t :: %__MODULE__{ + call: Expr.Call.t() | Expr.MethodCall.t(), + meta: Meta.t() | nil + } + end + + defmodule If do + @moduledoc """ + Represents an if statement with optional elseif and else clauses. + + ```lua + if condition then + block + elseif condition2 then + block2 + else + else_block + end + ``` + """ + defstruct [:condition, :then_block, :elseifs, :else_block, :meta] + + @type elseif_clause :: {Expr.t(), Block.t()} + + @type t :: %__MODULE__{ + condition: Expr.t(), + then_block: Block.t(), + elseifs: [elseif_clause()], + else_block: Block.t() | nil, + meta: Meta.t() | nil + } + end + + defmodule While do + @moduledoc """ + Represents a while loop: `while condition do block end` + """ + defstruct [:condition, :body, :meta] + + @type t :: %__MODULE__{ + condition: Expr.t(), + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule Repeat do + @moduledoc """ + Represents a repeat-until loop: `repeat block until condition` + """ + defstruct [:body, :condition, :meta] + + @type t :: %__MODULE__{ + body: Block.t(), + condition: Expr.t(), + meta: Meta.t() | nil + } + end + + defmodule ForNum do + @moduledoc """ + Represents a numeric for loop: `for var = start, limit, step do block end` + + The step is optional and defaults to 1. 
+ """ + defstruct [:var, :start, :limit, :step, :body, :meta] + + @type t :: %__MODULE__{ + var: String.t(), + start: Expr.t(), + limit: Expr.t(), + step: Expr.t() | nil, + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule ForIn do + @moduledoc """ + Represents a generic for loop: `for vars in exprs do block end` + + ```lua + for k, v in pairs(t) do + -- block + end + ``` + """ + defstruct [:vars, :iterators, :body, :meta] + + @type t :: %__MODULE__{ + vars: [String.t()], + iterators: [Expr.t()], + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule Do do + @moduledoc """ + Represents a do block: `do block end` + + Used to create a new scope. + """ + defstruct [:body, :meta] + + @type t :: %__MODULE__{ + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule Return do + @moduledoc """ + Represents a return statement: `return exprs` + + Can return multiple values. + """ + defstruct [:values, :meta] + + @type t :: %__MODULE__{ + values: [Expr.t()], + meta: Meta.t() | nil + } + end + + defmodule Break do + @moduledoc """ + Represents a break statement: `break` + """ + defstruct [:meta] + @type t :: %__MODULE__{meta: Meta.t() | nil} + end + + defmodule Goto do + @moduledoc """ + Represents a goto statement: `goto label` + + Introduced in Lua 5.2. + """ + defstruct [:label, :meta] + + @type t :: %__MODULE__{ + label: String.t(), + meta: Meta.t() | nil + } + end + + defmodule Label do + @moduledoc """ + Represents a label: `::label::` + + Introduced in Lua 5.2. + """ + defstruct [:name, :meta] + + @type t :: %__MODULE__{ + name: String.t(), + meta: Meta.t() | nil + } + end +end diff --git a/lib/lua/ast/walker.ex b/lib/lua/ast/walker.ex new file mode 100644 index 0000000..5adc6cc --- /dev/null +++ b/lib/lua/ast/walker.ex @@ -0,0 +1,287 @@ +defmodule Lua.AST.Walker do + @moduledoc """ + AST traversal utilities using the visitor pattern. + + Provides functions for walking, mapping, and reducing over AST nodes. + + ## Examples + + # Simple traversal (side effects) + Walker.walk(ast, fn node -> + IO.inspect(node) + end) + + # Transform AST (double all numbers) + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + # Accumulate values (collect all variable names) + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _node, acc -> acc + end) + + # Post-order traversal + Walker.walk(ast, fn node -> ... end, order: :post) + """ + + alias Lua.AST.{Chunk, Block, Expr, Stmt} + + @type ast_node :: + Chunk.t() + | Block.t() + | Expr.t() + | Stmt.t() + + @type visitor :: (ast_node -> any()) + @type mapper :: (ast_node -> ast_node) + @type reducer :: (ast_node, acc :: any() -> any()) + + @type order :: :pre | :post + + @doc """ + Walks the AST, calling the visitor function for each node. + + The visitor is called in pre-order by default (parent before children). + Use `order: :post` for post-order traversal (children before parent). + + ## Options + + - `:order` - `:pre` (default) or `:post` + + ## Examples + + Walker.walk(ast, fn + %Expr.Number{value: n} -> IO.puts("Found number: \#{n}") + _node -> :ok + end) + """ + @spec walk(ast_node, visitor, keyword()) :: :ok + def walk(node, visitor, opts \\ []) do + order = Keyword.get(opts, :order, :pre) + do_walk(node, visitor, order) + :ok + end + + @doc """ + Maps over the AST, transforming nodes with the mapper function. + + The mapper is called in post-order (children before parent) to ensure + transformations propagate upward correctly. 
+ + ## Examples + + # Double all numbers + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + """ + @spec map(ast_node, mapper) :: ast_node + def map(node, mapper) do + do_map(node, mapper) + end + + @doc """ + Reduces the AST to a single value by calling the reducer function for each node. + + The reducer is called in pre-order by default. + + ## Examples + + # Collect all variable names + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _node, acc -> acc + end) + + # Count all nodes + Walker.reduce(ast, 0, fn _node, acc -> acc + 1 end) + """ + @spec reduce(ast_node, acc, reducer) :: acc when acc: any() + def reduce(node, initial, reducer) do + do_reduce(node, initial, reducer) + end + + # Private implementation + + # Walk in pre-order or post-order + defp do_walk(node, visitor, :pre) do + visitor.(node) + walk_children(node, visitor, :pre) + end + + defp do_walk(node, visitor, :post) do + walk_children(node, visitor, :post) + visitor.(node) + end + + defp walk_children(node, visitor, order) do + children(node) + |> Enum.each(fn child -> do_walk(child, visitor, order) end) + end + + # Map (post-order to transform bottom-up) + defp do_map(node, mapper) do + mapped_children = + case node do + # Chunk + %Chunk{block: block} = chunk -> + %{chunk | block: do_map(block, mapper)} + + # Block + %Block{stmts: stmts} = block -> + %{block | stmts: Enum.map(stmts, &do_map(&1, mapper))} + + # Expressions + %Expr.BinOp{left: left, right: right} = expr -> + %{expr | left: do_map(left, mapper), right: do_map(right, mapper)} + + %Expr.UnOp{operand: operand} = expr -> + %{expr | operand: do_map(operand, mapper)} + + %Expr.Table{fields: fields} = expr -> + mapped_fields = + Enum.map(fields, fn + {:list, value} -> {:list, do_map(value, mapper)} + {:record, key, value} -> {:record, do_map(key, mapper), do_map(value, mapper)} + end) + + %{expr | fields: mapped_fields} + + %Expr.Call{func: func, args: args} = expr -> + %{expr | func: do_map(func, mapper), args: Enum.map(args, &do_map(&1, mapper))} + + %Expr.MethodCall{object: obj, args: args} = expr -> + %{expr | object: do_map(obj, mapper), args: Enum.map(args, &do_map(&1, mapper))} + + %Expr.Index{table: table, key: key} = expr -> + %{expr | table: do_map(table, mapper), key: do_map(key, mapper)} + + %Expr.Property{table: table} = expr -> + %{expr | table: do_map(table, mapper)} + + %Expr.Function{body: body} = expr -> + %{expr | body: do_map(body, mapper)} + + # Statements + %Stmt.Assign{targets: targets, values: values} = stmt -> + %{stmt | targets: Enum.map(targets, &do_map(&1, mapper)), values: Enum.map(values, &do_map(&1, mapper))} + + %Stmt.Local{values: values} = stmt when is_list(values) -> + %{stmt | values: Enum.map(values, &do_map(&1, mapper))} + + %Stmt.Local{} = stmt -> + stmt + + %Stmt.LocalFunc{body: body} = stmt -> + %{stmt | body: do_map(body, mapper)} + + %Stmt.FuncDecl{body: body} = stmt -> + %{stmt | body: do_map(body, mapper)} + + %Stmt.CallStmt{call: call} = stmt -> + %{stmt | call: do_map(call, mapper)} + + %Stmt.If{condition: cond, then_block: then_block, elseifs: elseifs, else_block: else_block} = stmt -> + mapped_elseifs = Enum.map(elseifs, fn {c, b} -> {do_map(c, mapper), do_map(b, mapper)} end) + mapped_else = if else_block, do: do_map(else_block, mapper), else: nil + + %{stmt | + condition: do_map(cond, mapper), + then_block: do_map(then_block, mapper), + elseifs: mapped_elseifs, + else_block: mapped_else + } + + %Stmt.While{condition: cond, body: body} = 
stmt -> + %{stmt | condition: do_map(cond, mapper), body: do_map(body, mapper)} + + %Stmt.Repeat{body: body, condition: cond} = stmt -> + %{stmt | body: do_map(body, mapper), condition: do_map(cond, mapper)} + + %Stmt.ForNum{var: var, start: start, limit: limit, step: step, body: body} = stmt -> + mapped_step = if step, do: do_map(step, mapper), else: nil + + %{stmt | + start: do_map(start, mapper), + limit: do_map(limit, mapper), + step: mapped_step, + body: do_map(body, mapper) + } + + %Stmt.ForIn{vars: vars, iterators: iterators, body: body} = stmt -> + %{stmt | iterators: Enum.map(iterators, &do_map(&1, mapper)), body: do_map(body, mapper)} + + %Stmt.Do{body: body} = stmt -> + %{stmt | body: do_map(body, mapper)} + + %Stmt.Return{values: values} = stmt -> + %{stmt | values: Enum.map(values, &do_map(&1, mapper))} + + # Leaf nodes (no children) + _ -> + node + end + + mapper.(mapped_children) + end + + # Reduce (pre-order accumulation) + defp do_reduce(node, acc, reducer) do + acc = reducer.(node, acc) + + children(node) + |> Enum.reduce(acc, fn child, acc -> do_reduce(child, acc, reducer) end) + end + + # Extract children for traversal + defp children(node) do + case node do + # Chunk + %Chunk{block: block} -> [block] + + # Block + %Block{stmts: stmts} -> stmts + + # Expressions with children + %Expr.BinOp{left: left, right: right} -> [left, right] + %Expr.UnOp{operand: operand} -> [operand] + %Expr.Table{fields: fields} -> extract_table_fields(fields) + %Expr.Call{func: func, args: args} -> [func | args] + %Expr.MethodCall{object: obj, args: args} -> [obj | args] + %Expr.Index{table: table, key: key} -> [table, key] + %Expr.Property{table: table} -> [table] + %Expr.Function{body: body} -> [body] + + # Statements with children + %Stmt.Assign{targets: targets, values: values} -> targets ++ values + %Stmt.Local{values: values} when is_list(values) -> values + %Stmt.LocalFunc{body: body} -> [body] + %Stmt.FuncDecl{body: body} -> [body] + %Stmt.CallStmt{call: call} -> [call] + %Stmt.If{condition: cond, then_block: then_block, elseifs: elseifs, else_block: else_block} -> + elseif_nodes = Enum.flat_map(elseifs, fn {c, b} -> [c, b] end) + [cond, then_block | elseif_nodes] ++ if(else_block, do: [else_block], else: []) + %Stmt.While{condition: cond, body: body} -> [cond, body] + %Stmt.Repeat{body: body, condition: cond} -> [body, cond] + %Stmt.ForNum{start: start, limit: limit, step: step, body: body} -> + [start, limit] ++ if(step, do: [step], else: []) ++ [body] + %Stmt.ForIn{iterators: iterators, body: body} -> iterators ++ [body] + %Stmt.Do{body: body} -> [body] + %Stmt.Return{values: values} -> values + + # Leaf nodes (no children) + _ -> [] + end + end + + defp extract_table_fields(fields) do + Enum.flat_map(fields, fn + {:list, value} -> [value] + {:record, key, value} -> [key, value] + end) + end +end diff --git a/lib/lua/lexer.ex b/lib/lua/lexer.ex new file mode 100644 index 0000000..6774b77 --- /dev/null +++ b/lib/lua/lexer.ex @@ -0,0 +1,522 @@ +defmodule Lua.Lexer do + @moduledoc """ + Hand-written lexer for Lua 5.3 using Elixir binary pattern matching. + + Tokenizes Lua source code into a list of tokens with position tracking. 
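+
+ Whitespace and comments never produce tokens, and every successful scan ends
+ with an `:eof` token:
+
+     {:ok, tokens} = Lua.Lexer.tokenize("x = 1 -- comments are skipped")
+     # => identifier "x", the :assign operator, the number 1, then :eof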
+ """ + + @type position :: %{line: pos_integer(), column: pos_integer(), byte_offset: non_neg_integer()} + @type token :: + {:keyword, atom(), position()} + | {:identifier, String.t(), position()} + | {:number, number(), position()} + | {:string, String.t(), position()} + | {:operator, atom(), position()} + | {:delimiter, atom(), position()} + | {:eof, position()} + + @keywords ~w( + and break do else elseif end false for function goto if in + local nil not or repeat return then true until while + ) + + @doc """ + Tokenizes Lua source code into a list of tokens. + + ## Examples + + iex> Lua.Lexer.tokenize("local x = 42") + {:ok, [ + {:keyword, :local, %{line: 1, column: 1, byte_offset: 0}}, + {:identifier, "x", %{line: 1, column: 7, byte_offset: 6}}, + {:operator, :assign, %{line: 1, column: 9, byte_offset: 8}}, + {:number, 42, %{line: 1, column: 11, byte_offset: 10}}, + {:eof, %{line: 1, column: 13, byte_offset: 12}} + ]} + """ + @spec tokenize(String.t()) :: {:ok, [token()]} | {:error, term()} + def tokenize(code) when is_binary(code) do + pos = %{line: 1, column: 1, byte_offset: 0} + do_tokenize(code, [], pos) + end + + # End of input + defp do_tokenize(<<>>, acc, pos) do + {:ok, Enum.reverse([{:eof, pos} | acc])} + end + + # Whitespace (space, tab) + defp do_tokenize(<>, acc, pos) when c in [?\s, ?\t] do + new_pos = advance_column(pos, 1) + do_tokenize(rest, acc, new_pos) + end + + # Newline (LF) + defp do_tokenize(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1} + do_tokenize(rest, acc, new_pos) + end + + # Carriage return (CR, or CRLF) + defp do_tokenize(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 2} + do_tokenize(rest, acc, new_pos) + end + + defp do_tokenize(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1} + do_tokenize(rest, acc, new_pos) + end + + # Comments: single-line (--) or multi-line (--[[ ... ]]) + defp do_tokenize(<<"--[", rest::binary>>, acc, pos) do + # Check if it's a multi-line comment + case rest do + <<"[", _::binary>> -> + # Multi-line comment --[[ ... ]] + scan_multiline_comment(rest, acc, advance_column(pos, 3), 0) + + _ -> + # Single-line comment starting with --[ + scan_single_line_comment(rest, acc, advance_column(pos, 3)) + end + end + + defp do_tokenize(<<"--", rest::binary>>, acc, pos) do + scan_single_line_comment(rest, acc, advance_column(pos, 2)) + end + + # Strings: double-quoted + defp do_tokenize(<>, acc, pos) do + scan_string(rest, "", acc, advance_column(pos, 1), pos, ?") + end + + # Strings: single-quoted + defp do_tokenize(<>, acc, pos) do + scan_string(rest, "", acc, advance_column(pos, 1), pos, ?') + end + + # Strings: multi-line [[ ... ]] or [=[ ... 
]=]
+  defp do_tokenize(<<"[", rest::binary>>, acc, pos) do
+    case scan_long_bracket(rest, 0) do
+      {:ok, equals, after_bracket} ->
+        scan_long_string(after_bracket, "", acc, advance_column(pos, 2 + equals), pos, equals)
+
+      :error ->
+        # Not a long string, treat as delimiter
+        token = {:delimiter, :lbracket, pos}
+        do_tokenize(rest, [token | acc], advance_column(pos, 1))
+    end
+  end
+
+  # Numbers: hex (0x, 0X)
+  defp do_tokenize(<<"0", x, rest::binary>>, acc, pos) when x in [?x, ?X] do
+    scan_hex_number(rest, "", acc, advance_column(pos, 2), pos)
+  end
+
+  # Numbers: decimal or float
+  defp do_tokenize(<<c, rest::binary>>, acc, pos) when c in ?0..?9 do
+    scan_number(<<c, rest::binary>>, "", acc, pos, pos)
+  end
+
+  # Three-character operators
+  defp do_tokenize(<<"...", rest::binary>>, acc, pos) do
+    token = {:operator, :vararg, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 3))
+  end
+
+  # Two-character operators
+  defp do_tokenize(<<"==", rest::binary>>, acc, pos) do
+    token = {:operator, :eq, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 2))
+  end
+
+  defp do_tokenize(<<"~=", rest::binary>>, acc, pos) do
+    token = {:operator, :ne, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 2))
+  end
+
+  defp do_tokenize(<<"<=", rest::binary>>, acc, pos) do
+    token = {:operator, :le, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 2))
+  end
+
+  defp do_tokenize(<<">=", rest::binary>>, acc, pos) do
+    token = {:operator, :ge, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 2))
+  end
+
+  defp do_tokenize(<<"..", rest::binary>>, acc, pos) do
+    token = {:operator, :concat, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 2))
+  end
+
+  defp do_tokenize(<<"::", rest::binary>>, acc, pos) do
+    token = {:delimiter, :double_colon, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 2))
+  end
+
+  defp do_tokenize(<<"//", rest::binary>>, acc, pos) do
+    token = {:operator, :floordiv, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 2))
+  end
+
+  # Single-character operators and delimiters
+  defp do_tokenize(<<c, rest::binary>>, acc, pos) when c in [?+, ?-, ?*, ?/, ?%, ?^, ?#] do
+    op =
+      case c do
+        ?+ -> :add
+        ?- -> :sub
+        ?* -> :mul
+        ?/ -> :div
+        ?% -> :mod
+        ?^ -> :pow
+        ?# -> :len
+      end
+
+    token = {:operator, op, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 1))
+  end
+
+  defp do_tokenize(<<c, rest::binary>>, acc, pos) when c in [?<, ?>, ?=] do
+    op =
+      case c do
+        ?< -> :lt
+        ?> -> :gt
+        ?= -> :assign
+      end
+
+    token = {:operator, op, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 1))
+  end
+
+  defp do_tokenize(<<c, rest::binary>>, acc, pos)
+       when c in [?(, ?), ?{, ?}, ?], ?;, ?,, ?., ?:] do
+    delim =
+      case c do
+        ?( -> :lparen
+        ?) -> :rparen
+        ?{ -> :lbrace
+        ?} -> :rbrace
+        ?] -> :rbracket
+        ?; -> :semicolon
+        ?, -> :comma
+        ?. -> :dot
+        ?: -> :colon
+      end
+
+    token = {:delimiter, delim, pos}
+    do_tokenize(rest, [token | acc], advance_column(pos, 1))
+  end
+
+  # Identifiers and keywords
+  defp do_tokenize(<<c, rest::binary>>, acc, pos)
+       when c in ?a..?z or c in ?A..?Z or c == ?_ do
+    scan_identifier(<<c, rest::binary>>, "", acc, pos, pos)
+  end
+
+  # Unexpected character
+  defp do_tokenize(<<c, _rest::binary>>, _acc, pos) do
+    {:error, {:unexpected_character, c, pos}}
+  end
+
+  # Scan single-line comment (skip until newline)
+  defp scan_single_line_comment(<<?\n, rest::binary>>, acc, pos) do
+    new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1}
+    do_tokenize(rest, acc, new_pos)
+  end
+
+  defp scan_single_line_comment(<<?\r, ?\n, rest::binary>>, acc, pos) do
+    new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 2}
+    do_tokenize(rest, acc, new_pos)
+  end
+
+  defp scan_single_line_comment(<<?\r, rest::binary>>, acc, pos) do
+    new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1}
+    do_tokenize(rest, acc, new_pos)
+  end
+
+  defp scan_single_line_comment(<<>>, acc, pos) do
+    {:ok, Enum.reverse([{:eof, pos} | acc])}
+  end
+
+  defp scan_single_line_comment(<<_, rest::binary>>, acc, pos) do
+    scan_single_line_comment(rest, acc, advance_column(pos, 1))
+  end
+
+  # Scan multi-line comment --[[ ... ]] or --[=[ ... ]=]
+  defp scan_multiline_comment(<<"[", rest::binary>>, acc, pos, level) do
+    scan_multiline_comment_content(rest, acc, advance_column(pos, 1), level)
+  end
+
+  defp scan_multiline_comment(rest, acc, pos, _level) do
+    # Not a multi-line comment after all, treat as single-line
+    scan_single_line_comment(rest, acc, pos)
+  end
+
+  defp scan_multiline_comment_content(<<"]", rest::binary>>, acc, pos, level) do
+    case try_close_long_bracket(rest, level, 0) do
+      {:ok, after_bracket} ->
+        new_pos = advance_column(pos, 2 + level)
+        do_tokenize(after_bracket, acc, new_pos)
+
+      :error ->
+        scan_multiline_comment_content(rest, acc, advance_column(pos, 1), level)
+    end
+  end
+
+  defp scan_multiline_comment_content(<<?\n, rest::binary>>, acc, pos, level) do
+    new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1}
+    scan_multiline_comment_content(rest, acc, new_pos, level)
+  end
+
+  defp scan_multiline_comment_content(<<>>, _acc, pos, _level) do
+    {:error, {:unclosed_comment, pos}}
+  end
+
+  defp scan_multiline_comment_content(<<_, rest::binary>>, acc, pos, level) do
+    scan_multiline_comment_content(rest, acc, advance_column(pos, 1), level)
+  end
+
+  # Scan quoted string
+  defp scan_string(<<quote, rest::binary>>, str_acc, acc, pos, start_pos, quote) do
+    # Closing quote
+    token = {:string, str_acc, start_pos}
+    do_tokenize(rest, [token | acc], pos)
+  end
+
+  defp scan_string(<<?\\, esc, rest::binary>>, str_acc, acc, pos, start_pos, quote) do
+    # Escape sequence
+    case escape_char(esc) do
+      {:ok, char} ->
+        scan_string(rest, str_acc <> <<char>>, acc, advance_column(pos, 2), start_pos, quote)
+
+      :error ->
+        # Invalid escape, but continue scanning
+        scan_string(rest, str_acc <> <<esc>>, acc, advance_column(pos, 2), start_pos, quote)
+    end
+  end
+
+  defp scan_string(<<?\n, _rest::binary>>, _str_acc, _acc, pos, _start_pos, _quote) do
+    {:error, {:unclosed_string, pos}}
+  end
+
+  defp scan_string(<<>>, _str_acc, _acc, pos, _start_pos, _quote) do
+    {:error, {:unclosed_string, pos}}
+  end
+
+  defp scan_string(<<c, rest::binary>>, str_acc, acc, pos, start_pos, quote) do
+    scan_string(rest, str_acc <> <<c>>, acc, advance_column(pos, 1), start_pos, quote)
+  end
+
+  # Escape character mapping
+  defp escape_char(?a), do: {:ok, ?\a}
+  defp escape_char(?b), do: {:ok, ?\b}
+  defp escape_char(?f), do: {:ok, ?\f}
+  defp escape_char(?n), do: {:ok, ?\n}
+  defp escape_char(?r), do: {:ok, ?\r}
+  defp escape_char(?t), do: {:ok, ?\t}
+  defp escape_char(?v), do: {:ok, ?\v}
+  defp escape_char(?\\), do: {:ok, ?\\}
+  defp escape_char(?"), do: {:ok, ?"}
+  defp escape_char(?'), do: {:ok, ?'}
+  defp escape_char(_), do: :error
+
+  # Scan long bracket for level: [[ or [=[ or [==[ etc.
+  defp scan_long_bracket(rest, equals) do
+    case rest do
+      <<"=", after_eq::binary>> ->
+        scan_long_bracket(after_eq, equals + 1)
+
+      <<"[", after_bracket::binary>> ->
+        {:ok, equals, after_bracket}
+
+      _ ->
+        :error
+    end
+  end
+
+  # Try to close long bracket: ]] or ]=] or ]==] etc.
+  defp try_close_long_bracket(rest, target_level, current_level) do
+    if current_level == target_level do
+      case rest do
+        <<"]", after_bracket::binary>> ->
+          {:ok, after_bracket}
+
+        _ ->
+          :error
+      end
+    else
+      case rest do
+        <<"=", after_eq::binary>> ->
+          try_close_long_bracket(after_eq, target_level, current_level + 1)
+
+        _ ->
+          :error
+      end
+    end
+  end
+
+  # Scan long string [[ ... ]] or [=[ ... ]=]
+  defp scan_long_string(<<"]", rest::binary>>, str_acc, acc, pos, start_pos, level) do
+    case try_close_long_bracket(rest, level, 0) do
+      {:ok, after_bracket} ->
+        token = {:string, str_acc, start_pos}
+        new_pos = advance_column(pos, 2 + level)
+        do_tokenize(after_bracket, [token | acc], new_pos)
+
+      :error ->
+        scan_long_string(rest, str_acc <> "]", acc, advance_column(pos, 1), start_pos, level)
+    end
+  end
+
+  defp scan_long_string(<<?\n, rest::binary>>, str_acc, acc, pos, start_pos, level) do
+    new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1}
+    scan_long_string(rest, str_acc <> "\n", acc, new_pos, start_pos, level)
+  end
+
+  defp scan_long_string(<<>>, _str_acc, _acc, pos, _start_pos, _level) do
+    {:error, {:unclosed_long_string, pos}}
+  end
+
+  defp scan_long_string(<<c, rest::binary>>, str_acc, acc, pos, start_pos, level) do
+    scan_long_string(rest, str_acc <> <<c>>, acc, advance_column(pos, 1), start_pos, level)
+  end
+
+  # Scan identifier or keyword
+  defp scan_identifier(<<c, rest::binary>>, id_acc, acc, pos, start_pos)
+       when c in ?a..?z or c in ?A..?Z or c in ?0..?9 or c == ?_ do
+    scan_identifier(rest, id_acc <> <<c>>, acc, advance_column(pos, 1), start_pos)
+  end
+
+  defp scan_identifier(rest, id_acc, acc, pos, start_pos) do
+    # Check if it's a keyword
+    token =
+      if id_acc in @keywords do
+        {:keyword, String.to_atom(id_acc), start_pos}
+      else
+        {:identifier, id_acc, start_pos}
+      end
+
+    do_tokenize(rest, [token | acc], pos)
+  end
+
+  # Scan decimal number
+  defp scan_number(<<c, rest::binary>>, num_acc, acc, pos, start_pos)
+       when c in ?0..?9 do
+    scan_number(rest, num_acc <> <<c>>, acc, advance_column(pos, 1), start_pos)
+  end
+
+  defp scan_number(<<?.>>, num_acc, acc, pos, start_pos) do
+    # Trailing dot is not part of the number
+    finalize_number(num_acc, <<".">>, acc, pos, start_pos)
+  end
+
+  defp scan_number(<<".", c, rest::binary>>, num_acc, acc, pos, start_pos)
+       when c in ?0..?9 do
+    # Decimal point with digit following
+    scan_float(rest, num_acc <> "." <> <<c>>, acc, advance_column(pos, 2), start_pos)
+  end
+
+  defp scan_number(<<".", rest::binary>>, num_acc, acc, pos, start_pos) do
+    # Decimal point but no digit following - finalize number, reprocess "."
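+    # Illustrative trace (not part of the scanner): for the input "1..y" this
+    # clause fires with num_acc "1" once the scanner reaches "..y". It emits
+    # {:number, 1, pos} and re-tokenizes "..y", which then yields the :concat
+    # operator followed by the identifier "y".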
+ finalize_number(num_acc, <<".", rest::binary>>, acc, pos, start_pos) + end + + defp scan_number(<>, num_acc, acc, pos, start_pos) + when c in [?e, ?E] do + # Scientific notation + scan_exponent(<>, num_acc, acc, pos, start_pos) + end + + defp scan_number(rest, num_acc, acc, pos, start_pos) do + finalize_number(num_acc, rest, acc, pos, start_pos) + end + + # Scan float part (after decimal point) + defp scan_float(<>, num_acc, acc, pos, start_pos) + when c in ?0..?9 do + scan_float(rest, num_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_float(<>, num_acc, acc, pos, start_pos) + when c in [?e, ?E] do + scan_exponent(<>, num_acc, acc, pos, start_pos) + end + + defp scan_float(rest, num_acc, acc, pos, start_pos) do + finalize_number(num_acc, rest, acc, pos, start_pos) + end + + # Scan scientific notation exponent + defp scan_exponent(<>, num_acc, acc, pos, start_pos) + when c in [?e, ?E] and sign in [?+, ?-] do + scan_exponent_digits(rest, num_acc <> <>, acc, advance_column(pos, 2), start_pos) + end + + defp scan_exponent(<>, num_acc, acc, pos, start_pos) + when c in [?e, ?E] do + scan_exponent_digits(rest, num_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_exponent_digits(<>, num_acc, acc, pos, start_pos) + when c in ?0..?9 do + scan_exponent_digits(rest, num_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_exponent_digits(rest, num_acc, acc, pos, start_pos) do + finalize_number(num_acc, rest, acc, pos, start_pos) + end + + # Scan hexadecimal number (0x...) + defp scan_hex_number(<>, hex_acc, acc, pos, start_pos) + when c in ?0..?9 or c in ?a..?f or c in ?A..?F do + scan_hex_number(rest, hex_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_hex_number(rest, hex_acc, acc, pos, start_pos) do + case Integer.parse(hex_acc, 16) do + {num, ""} -> + token = {:number, num, start_pos} + do_tokenize(rest, [token | acc], pos) + + _ -> + {:error, {:invalid_hex_number, start_pos}} + end + end + + # Finalize number token + defp finalize_number(num_str, rest, acc, pos, start_pos) do + case parse_number(num_str) do + {:ok, num} -> + token = {:number, num, start_pos} + do_tokenize(rest, [token | acc], pos) + + {:error, reason} -> + {:error, {reason, start_pos}} + end + end + + # Parse number string to integer or float + defp parse_number(num_str) do + cond do + String.contains?(num_str, ".") or String.contains?(num_str, "e") or + String.contains?(num_str, "E") -> + case Float.parse(num_str) do + {num, ""} -> {:ok, num} + _ -> {:error, :invalid_number} + end + + true -> + case Integer.parse(num_str) do + {num, ""} -> {:ok, num} + _ -> {:error, :invalid_number} + end + end + end + + # Position tracking helpers + defp advance_column(pos, n) do + %{pos | column: pos.column + n, byte_offset: pos.byte_offset + n} + end +end diff --git a/lib/lua/parser.ex b/lib/lua/parser.ex new file mode 100644 index 0000000..cdc684c --- /dev/null +++ b/lib/lua/parser.ex @@ -0,0 +1,1166 @@ +defmodule Lua.Parser do + @moduledoc """ + Hand-written recursive descent parser for Lua 5.3. + + Uses Pratt parsing for operator precedence in expressions. + """ + + alias Lua.AST.{Meta, Expr, Stmt, Block, Chunk} + alias Lua.Parser.Pratt + alias Lua.Lexer + + @type token :: Lexer.token() + @type parse_result(t) :: {:ok, t, [token()]} | {:error, term()} + + alias Lua.Parser.Error + + @doc """ + Parses Lua source code into an AST. + + Returns `{:ok, chunk}` on success or `{:error, formatted_error}` on failure. 
+ The error is a beautifully formatted string with context and suggestions. + + ## Examples + + iex> Lua.Parser.parse("local x = 42") + {:ok, %Lua.AST.Chunk{...}} + + iex> {:error, error_msg} = Lua.Parser.parse("if x then") + iex> String.contains?(error_msg, "Parse Error") + true + """ + @spec parse(String.t()) :: {:ok, Chunk.t()} | {:error, String.t()} + def parse(code) when is_binary(code) do + case Lexer.tokenize(code) do + {:ok, tokens} -> + case parse_chunk(tokens) do + {:ok, chunk} -> + {:ok, chunk} + + {:error, reason} -> + error = convert_error(reason, code) + formatted = Error.format(error, code) + {:error, formatted} + end + + {:error, reason} -> + error = convert_lexer_error(reason, code) + formatted = Error.format(error, code) + {:error, formatted} + end + end + + @doc """ + Parses Lua source code and returns raw error information. + + Use this when you want to handle errors programmatically instead of + displaying them to users. + """ + @spec parse_raw(String.t()) :: {:ok, Chunk.t()} | {:error, term()} + def parse_raw(code) when is_binary(code) do + case Lexer.tokenize(code) do + {:ok, tokens} -> + parse_chunk(tokens) + + {:error, reason} -> + {:error, {:lexer_error, reason}} + end + end + + @doc """ + Parses a chunk (top-level block) from a token list. + """ + @spec parse_chunk([token()]) :: {:ok, Chunk.t()} | {:error, term()} + def parse_chunk(tokens) do + case parse_block(tokens) do + {:ok, block, rest} -> + case rest do + [{:eof, _}] -> + {:ok, Chunk.new(block)} + + [{type, _, pos} | _] -> + {:error, {:unexpected_token, type, pos, "Expected end of input"}} + end + + {:error, reason} -> + {:error, reason} + end + end + + # Block parsing (sequence of statements) + defp parse_block(tokens) do + parse_block_acc(tokens, []) + end + + defp parse_block_acc(tokens, stmts) do + case peek(tokens) do + # Block terminators + {:keyword, terminator, _} when terminator in [:end, :else, :elseif, :until] -> + {:ok, Block.new(Enum.reverse(stmts)), tokens} + + {:eof, _} -> + {:ok, Block.new(Enum.reverse(stmts)), tokens} + + _ -> + case parse_stmt(tokens) do + {:ok, stmt, rest} -> + parse_block_acc(rest, [stmt | stmts]) + + {:error, reason} -> + {:error, reason} + end + end + end + + # Statement parsing (placeholder - will be implemented in Phase 3) + defp parse_stmt(tokens) do + case peek(tokens) do + {:keyword, :return, _} -> + parse_return(tokens) + + {:keyword, :local, _} -> + parse_local(tokens) + + {:keyword, :if, _} -> + parse_if(tokens) + + {:keyword, :while, _} -> + parse_while(tokens) + + {:keyword, :repeat, _} -> + parse_repeat(tokens) + + {:keyword, :for, _} -> + parse_for(tokens) + + {:keyword, :function, _} -> + parse_function_decl(tokens) + + {:keyword, :do, _} -> + parse_do(tokens) + + {:keyword, :break, _} -> + parse_break(tokens) + + {:keyword, :goto, _} -> + parse_goto(tokens) + + {:delimiter, :double_colon, _} -> + parse_label(tokens) + + # Semicolon (statement separator, optional) + {:delimiter, :semicolon, _} -> + {_, rest} = consume(tokens) + parse_stmt(rest) + + _ -> + # Try to parse as assignment or function call + parse_assign_or_call(tokens) + end + end + + # Placeholder implementations for statements (Phase 3) + defp parse_return([{:keyword, :return, pos} | rest]) do + case peek(rest) do + # End of block or statement + {:keyword, terminator, _} when terminator in [:end, :else, :elseif, :until] -> + {:ok, %Stmt.Return{values: [], meta: Meta.new(pos)}, rest} + + {:eof, _} -> + {:ok, %Stmt.Return{values: [], meta: Meta.new(pos)}, rest} + + {:delimiter, :semicolon, _} 
-> + {_, rest2} = consume(rest) + {:ok, %Stmt.Return{values: [], meta: Meta.new(pos)}, rest2} + + _ -> + case parse_expr_list(rest) do + {:ok, exprs, rest2} -> + {:ok, %Stmt.Return{values: exprs, meta: Meta.new(pos)}, rest2} + + {:error, reason} -> + {:error, reason} + end + end + end + + defp parse_local([{:keyword, :local, pos} | rest]) do + case peek(rest) do + {:keyword, :function, _} -> + # local function name() ... end + {_, rest2} = consume(rest) + + case expect(rest2, :identifier) do + {:ok, {_, name, _}, rest3} -> + with {:ok, _, rest4} <- expect(rest3, :delimiter, :lparen), + {:ok, params, rest5} <- parse_param_list(rest4), + {:ok, _, rest6} <- expect(rest5, :delimiter, :rparen), + {:ok, body, rest7} <- parse_block(rest6), + {:ok, _, rest8} <- expect(rest7, :keyword, :end) do + {:ok, %Stmt.LocalFunc{name: name, params: params, body: body, meta: Meta.new(pos)}, + rest8} + end + + {:error, reason} -> + {:error, reason} + end + + {:identifier, _, _} -> + # local name1, name2 = expr1, expr2 + case parse_name_list(rest) do + {:ok, names, rest2} -> + case peek(rest2) do + {:operator, :assign, _} -> + {_, rest3} = consume(rest2) + + case parse_expr_list(rest3) do + {:ok, values, rest4} -> + {:ok, %Stmt.Local{names: names, values: values, meta: Meta.new(pos)}, rest4} + + {:error, reason} -> + {:error, reason} + end + + _ -> + # Local without initialization + {:ok, %Stmt.Local{names: names, values: [], meta: Meta.new(pos)}, rest2} + end + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:error, {:unexpected_token, peek(rest), "Expected identifier or 'function' after 'local'"}} + end + end + + defp parse_if([{:keyword, :if, pos} | rest]) do + with {:ok, condition, rest2} <- parse_expr(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :then), + {:ok, then_block, rest4} <- parse_block(rest3), + {:ok, elseifs, else_block, rest5} <- parse_elseifs(rest4) do + case expect(rest5, :keyword, :end) do + {:ok, _, rest6} -> + {:ok, + %Stmt.If{ + condition: condition, + then_block: then_block, + elseifs: elseifs, + else_block: else_block, + meta: Meta.new(pos) + }, rest6} + + {:error, reason} -> + {:error, reason} + end + end + end + + defp parse_elseifs(tokens) do + case peek(tokens) do + {:keyword, :elseif, _} -> + {_, rest} = consume(tokens) + + with {:ok, condition, rest2} <- parse_expr(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :then), + {:ok, block, rest4} <- parse_block(rest3), + {:ok, more_elseifs, else_block, rest5} <- parse_elseifs(rest4) do + {:ok, [{condition, block} | more_elseifs], else_block, rest5} + end + + {:keyword, :else, _} -> + {_, rest} = consume(tokens) + + case parse_block(rest) do + {:ok, else_block, rest2} -> + {:ok, [], else_block, rest2} + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, [], nil, tokens} + end + end + + defp parse_while([{:keyword, :while, pos} | rest]) do + with {:ok, condition, rest2} <- parse_expr(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :do), + {:ok, body, rest4} <- parse_block(rest3), + {:ok, _, rest5} <- expect(rest4, :keyword, :end) do + {:ok, %Stmt.While{condition: condition, body: body, meta: Meta.new(pos)}, rest5} + end + end + + defp parse_repeat([{:keyword, :repeat, pos} | rest]) do + with {:ok, body, rest2} <- parse_block(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :until), + {:ok, condition, rest4} <- parse_expr(rest3) do + {:ok, %Stmt.Repeat{body: body, condition: condition, meta: Meta.new(pos)}, rest4} + end + end + + defp parse_for([{:keyword, :for, pos} | rest]) do + case expect(rest, 
:identifier) do + {:ok, {_, var, _}, rest2} -> + case peek(rest2) do + {:operator, :assign, _} -> + # Numeric for: for var = start, limit, step do ... end + {_, rest3} = consume(rest2) + + with {:ok, start, rest4} <- parse_expr(rest3), + {:ok, _, rest5} <- expect(rest4, :delimiter, :comma), + {:ok, limit, rest6} <- parse_expr(rest5), + {:ok, step, rest7} <- parse_for_step(rest6), + {:ok, _, rest8} <- expect(rest7, :keyword, :do), + {:ok, body, rest9} <- parse_block(rest8), + {:ok, _, rest10} <- expect(rest9, :keyword, :end) do + {:ok, + %Stmt.ForNum{ + var: var, + start: start, + limit: limit, + step: step, + body: body, + meta: Meta.new(pos) + }, rest10} + end + + {:delimiter, :comma, _} -> + # Generic for: for var1, var2 in exprs do ... end + parse_generic_for([var], rest2, pos) + + {:keyword, :in, _} -> + # Generic for with single variable + parse_generic_for([var], rest2, pos) + + _ -> + {:error, {:unexpected_token, peek(rest2), "Expected '=' or 'in' after for variable"}} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_for_step(tokens) do + case peek(tokens) do + {:delimiter, :comma, _} -> + {_, rest} = consume(tokens) + parse_expr(rest) + + _ -> + {:ok, nil, tokens} + end + end + + defp parse_generic_for(vars, tokens, start_pos) do + case peek(tokens) do + {:delimiter, :comma, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, var, _}, rest2} -> + parse_generic_for(vars ++ [var], rest2, start_pos) + + {:error, reason} -> + {:error, reason} + end + + {:keyword, :in, _} -> + {_, rest} = consume(tokens) + + with {:ok, iterators, rest2} <- parse_expr_list(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :do), + {:ok, body, rest4} <- parse_block(rest3), + {:ok, _, rest5} <- expect(rest4, :keyword, :end) do + {:ok, + %Stmt.ForIn{vars: vars, iterators: iterators, body: body, meta: Meta.new(start_pos)}, + rest5} + end + + _ -> + {:error, {:unexpected_token, peek(tokens), "Expected ',' or 'in' in for loop"}} + end + end + + defp parse_function_decl([{:keyword, :function, pos} | rest]) do + case parse_function_name(rest) do + {:ok, name_parts, is_method, rest2} -> + with {:ok, _, rest3} <- expect(rest2, :delimiter, :lparen), + {:ok, params, rest4} <- parse_param_list(rest3), + {:ok, _, rest5} <- expect(rest4, :delimiter, :rparen), + {:ok, body, rest6} <- parse_block(rest5), + {:ok, _, rest7} <- expect(rest6, :keyword, :end) do + {:ok, + %Stmt.FuncDecl{ + name: name_parts, + params: params, + body: body, + is_method: is_method, + meta: Meta.new(pos) + }, rest7} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_function_name(tokens) do + case expect(tokens, :identifier) do + {:ok, {_, name, _}, rest} -> + parse_function_name_rest([name], rest) + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_function_name_rest(names, tokens) do + case peek(tokens) do + {:delimiter, :dot, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, name, _}, rest2} -> + parse_function_name_rest(names ++ [name], rest2) + + {:error, reason} -> + {:error, reason} + end + + {:delimiter, :colon, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, name, _}, rest2} -> + {:ok, names ++ [name], true, rest2} + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, names, false, tokens} + end + end + + defp parse_do([{:keyword, :do, pos} | rest]) do + with {:ok, body, rest2} <- parse_block(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :end) 
do + {:ok, %Stmt.Do{body: body, meta: Meta.new(pos)}, rest3} + end + end + + defp parse_break([{:keyword, :break, pos} | rest]) do + {:ok, %Stmt.Break{meta: Meta.new(pos)}, rest} + end + + defp parse_goto([{:keyword, :goto, pos} | rest]) do + case expect(rest, :identifier) do + {:ok, {_, label, _}, rest2} -> + {:ok, %Stmt.Goto{label: label, meta: Meta.new(pos)}, rest2} + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_label([{:delimiter, :double_colon, pos} | rest]) do + case expect(rest, :identifier) do + {:ok, {_, name, _}, rest2} -> + case expect(rest2, :delimiter, :double_colon) do + {:ok, _, rest3} -> + {:ok, %Stmt.Label{name: name, meta: Meta.new(pos)}, rest3} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_assign_or_call(tokens) do + # This is the most complex case - we need to parse a potential lvalue or call + # Start by parsing an expression (which could be a variable, call, property access, etc.) + case parse_expr(tokens) do + {:ok, expr, rest} -> + case peek(rest) do + {:operator, :assign, _} -> + # It's an assignment + parse_assignment([expr], rest) + + {:delimiter, :comma, _} -> + # Multiple targets, must be assignment + parse_assignment_targets([expr], rest) + + _ -> + # It's a call statement (or error if not a call) + case expr do + %Expr.Call{} = call -> + {:ok, %Stmt.CallStmt{call: call, meta: nil}, rest} + + %Expr.MethodCall{} = call -> + {:ok, %Stmt.CallStmt{call: call, meta: nil}, rest} + + _ -> + {:error, + {:unexpected_expression, "Expression statement must be a function call"}} + end + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_assignment_targets(targets, [{:delimiter, :comma, _} | rest]) do + case parse_expr(rest) do + {:ok, expr, rest2} -> + case peek(rest2) do + {:delimiter, :comma, _} -> + parse_assignment_targets(targets ++ [expr], rest2) + + {:operator, :assign, _} -> + parse_assignment(targets ++ [expr], rest2) + + _ -> + {:error, {:unexpected_token, peek(rest2), "Expected '=' or ',' in assignment"}} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_assignment(targets, [{:operator, :assign, _} | rest]) do + case parse_expr_list(rest) do + {:ok, values, rest2} -> + {:ok, %Stmt.Assign{targets: targets, values: values, meta: nil}, rest2} + + {:error, reason} -> + {:error, reason} + end + end + + # Helper: parse list of names (for local declarations, for loops) + defp parse_name_list(tokens) do + case expect(tokens, :identifier) do + {:ok, {_, name, _}, rest} -> + parse_name_list_rest([name], rest) + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_name_list_rest(names, tokens) do + case peek(tokens) do + {:delimiter, :comma, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, name, _}, rest2} -> + parse_name_list_rest(names ++ [name], rest2) + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, names, tokens} + end + end + + # Expression parsing with Pratt algorithm + @doc """ + Parses an expression with minimum precedence. 
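+
+  Illustrative sketch (structural only; `Meta` fields and token positions are
+  elided). Assuming `tokens` came from `Lua.Lexer.tokenize/1` for "1 + 2 * 3":
+
+      {:ok, expr, [{:eof, _}]} = parse_expr(tokens)
+      # expr => %Expr.BinOp{op: :add,
+      #           left: %Expr.Number{value: 1},
+      #           right: %Expr.BinOp{op: :mul, ...}}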
+ """ + @spec parse_expr([token()], non_neg_integer()) :: parse_result(Expr.t()) + def parse_expr(tokens, min_prec \\ 0) do + # Parse prefix (primary or unary operator) + case parse_prefix(tokens) do + {:ok, left, rest} -> + parse_infix(left, rest, min_prec) + + {:error, reason} -> + {:error, reason} + end + end + + # Parse prefix expressions (primary expressions and unary operators) + defp parse_prefix(tokens) do + case peek(tokens) do + # Literals + {:keyword, :nil, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Nil{meta: Meta.new(pos)}, rest} + + {:keyword, :true, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Bool{value: true, meta: Meta.new(pos)}, rest} + + {:keyword, :false, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Bool{value: false, meta: Meta.new(pos)}, rest} + + {:number, value, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Number{value: value, meta: Meta.new(pos)}, rest} + + {:string, value, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.String{value: value, meta: Meta.new(pos)}, rest} + + # Vararg + {:operator, :vararg, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Vararg{meta: Meta.new(pos)}, rest} + + # Identifier (variable) + {:identifier, name, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Var{name: name, meta: Meta.new(pos)}, rest} + + # Parenthesized expression + {:delimiter, :lparen, _} -> + parse_paren_expr(tokens) + + # Table constructor + {:delimiter, :lbrace, _} -> + parse_table(tokens) + + # Function expression + {:keyword, :function, _} -> + parse_function_expr(tokens) + + # Unary operators + {:keyword, :not, pos} -> + {_, rest} = consume(tokens) + parse_unary(:not, pos, rest) + + {:operator, :sub, pos} -> + {_, rest} = consume(tokens) + parse_unary(:sub, pos, rest) + + {:operator, :len, pos} -> + {_, rest} = consume(tokens) + parse_unary(:len, pos, rest) + + {type, _, pos} -> + {:error, {:unexpected_token, type, pos, "Expected expression"}} + + nil -> + {:error, {:unexpected_end, "Expected expression"}} + end + end + + defp parse_unary(op, pos, tokens) do + unop = Pratt.token_to_unop(op) + prec = Pratt.prefix_binding_power(op) + + case parse_expr(tokens, prec) do + {:ok, operand, rest} -> + {:ok, %Expr.UnOp{op: unop, operand: operand, meta: Meta.new(pos)}, rest} + + {:error, reason} -> + {:error, reason} + end + end + + # Parse infix expressions (binary operators and postfix) + defp parse_infix(left, tokens, min_prec) do + case peek(tokens) do + {:keyword, op, pos} when op in [:and, :or] -> + if Pratt.is_binary_op?(op) do + case Pratt.binding_power(op) do + {left_bp, right_bp} when left_bp >= min_prec -> + {_, rest} = consume(tokens) + binop = Pratt.token_to_binop(op) + + case parse_expr(rest, right_bp) do + {:ok, right, rest2} -> + new_left = %Expr.BinOp{ + op: binop, + left: left, + right: right, + meta: Meta.new(pos) + } + + parse_infix(new_left, rest2, min_prec) + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, left, tokens} + end + else + {:ok, left, tokens} + end + + {:operator, op, pos} -> + cond do + Pratt.is_binary_op?(op) -> + case Pratt.binding_power(op) do + {left_bp, right_bp} when left_bp >= min_prec -> + {_, rest} = consume(tokens) + binop = Pratt.token_to_binop(op) + + case parse_expr(rest, right_bp) do + {:ok, right, rest2} -> + new_left = %Expr.BinOp{ + op: binop, + left: left, + right: right, + meta: Meta.new(pos) + } + + parse_infix(new_left, rest2, min_prec) + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, left, tokens} + end + + true -> + {:ok, left, tokens} + end + + 
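+      # Illustrative note (not executed): with the binding powers above,
+      # "1 - 2 - 3" parses the right operand with min_prec 10, so the second
+      # "-" (left_bp 9) is left for this loop and the result associates as
+      # ((1 - 2) - 3), while "2 ^ 3 ^ 2" ({16, 15}) associates as 2 ^ (3 ^ 2).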
# Postfix: function call + {:delimiter, :lparen, _} -> + case parse_call_args(tokens) do + {:ok, args, rest} -> + new_left = %Expr.Call{func: left, args: args, meta: nil} + parse_infix(new_left, rest, min_prec) + + {:error, reason} -> + {:error, reason} + end + + # Postfix: indexing + {:delimiter, :lbracket, _} -> + case parse_index(tokens) do + {:ok, key, rest} -> + new_left = %Expr.Index{table: left, key: key, meta: nil} + parse_infix(new_left, rest, min_prec) + + {:error, reason} -> + {:error, reason} + end + + # Postfix: property access or method call + {:delimiter, :dot, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, field, _}, rest2} -> + new_left = %Expr.Property{table: left, field: field, meta: nil} + parse_infix(new_left, rest2, min_prec) + + {:error, reason} -> + {:error, reason} + end + + {:delimiter, :colon, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, method, _}, rest2} -> + case parse_call_args(rest2) do + {:ok, args, rest3} -> + new_left = %Expr.MethodCall{object: left, method: method, args: args, meta: nil} + parse_infix(new_left, rest3, min_prec) + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, left, tokens} + end + end + + # Parse parenthesized expression: (expr) + defp parse_paren_expr([{:delimiter, :lparen, _} | rest]) do + case parse_expr(rest) do + {:ok, expr, rest2} -> + case expect(rest2, :delimiter, :rparen) do + {:ok, _, rest3} -> + {:ok, expr, rest3} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + + # Parse table constructor: { fields } + defp parse_table([{:delimiter, :lbrace, pos} | rest]) do + case parse_table_fields(rest, []) do + {:ok, fields, rest2} -> + case expect(rest2, :delimiter, :rbrace) do + {:ok, _, rest3} -> + {:ok, %Expr.Table{fields: Enum.reverse(fields), meta: Meta.new(pos)}, rest3} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_table_fields(tokens, acc) do + case peek(tokens) do + {:delimiter, :rbrace, _} -> + {:ok, acc, tokens} + + _ -> + case parse_table_field(tokens) do + {:ok, field, rest} -> + case peek(rest) do + {:delimiter, :comma, _} -> + {_, rest2} = consume(rest) + parse_table_fields(rest2, [field | acc]) + + {:delimiter, :semicolon, _} -> + {_, rest2} = consume(rest) + parse_table_fields(rest2, [field | acc]) + + {:delimiter, :rbrace, _} -> + {:ok, [field | acc], rest} + + _ -> + {:error, {:unexpected_token, peek(rest), "Expected ',' or '}' in table"}} + end + + {:error, reason} -> + {:error, reason} + end + end + end + + defp parse_table_field(tokens) do + case peek(tokens) do + # [expr] = expr (computed key) + {:delimiter, :lbracket, _} -> + {_, rest} = consume(tokens) + + with {:ok, key, rest2} <- parse_expr(rest), + {:ok, _, rest3} <- expect(rest2, :delimiter, :rbracket), + {:ok, _, rest4} <- expect(rest3, :operator, :assign), + {:ok, value, rest5} <- parse_expr(rest4) do + {:ok, {:pair, key, value}, rest5} + end + + # name = expr (named field) + {:identifier, name, pos} -> + rest = tl(tokens) + + case peek(rest) do + {:operator, :assign, _} -> + {_, rest2} = consume(rest) + + case parse_expr(rest2) do + {:ok, value, rest3} -> + key = %Expr.String{value: name, meta: Meta.new(pos)} + {:ok, {:pair, key, value}, rest3} + + {:error, reason} -> + {:error, reason} + end + + _ -> + # Just an expression (list entry) + case parse_expr(tokens) do + {:ok, expr, rest2} 
-> + {:ok, {:list, expr}, rest2} + + {:error, reason} -> + {:error, reason} + end + end + + _ -> + # Expression (list entry) + case parse_expr(tokens) do + {:ok, expr, rest} -> + {:ok, {:list, expr}, rest} + + {:error, reason} -> + {:error, reason} + end + end + end + + # Parse function expression: function(params) body end + defp parse_function_expr([{:keyword, :function, pos} | rest]) do + with {:ok, _, rest2} <- expect(rest, :delimiter, :lparen), + {:ok, params, rest3} <- parse_param_list(rest2), + {:ok, _, rest4} <- expect(rest3, :delimiter, :rparen), + {:ok, body, rest5} <- parse_block(rest4), + {:ok, _, rest6} <- expect(rest5, :keyword, :end) do + {:ok, %Expr.Function{params: params, body: body, meta: Meta.new(pos)}, rest6} + end + end + + defp parse_param_list(tokens) do + parse_param_list_acc(tokens, []) + end + + defp parse_param_list_acc(tokens, acc) do + case peek(tokens) do + {:delimiter, :rparen, _} -> + {:ok, Enum.reverse(acc), tokens} + + {:operator, :vararg, _} -> + {_, rest} = consume(tokens) + {:ok, Enum.reverse([:vararg | acc]), rest} + + {:identifier, name, _} -> + {_, rest} = consume(tokens) + + case peek(rest) do + {:delimiter, :comma, _} -> + {_, rest2} = consume(rest) + parse_param_list_acc(rest2, [name | acc]) + + _ -> + {:ok, Enum.reverse([name | acc]), rest} + end + + _ -> + {:error, {:unexpected_token, peek(tokens), "Expected parameter name or ')'"}} + end + end + + # Parse function call arguments: (args) + defp parse_call_args([{:delimiter, :lparen, _} | rest]) do + parse_expr_list_until(rest, :rparen) + end + + # Parse indexing: [key] + defp parse_index([{:delimiter, :lbracket, _} | rest]) do + case parse_expr(rest) do + {:ok, key, rest2} -> + case expect(rest2, :delimiter, :rbracket) do + {:ok, _, rest3} -> + {:ok, key, rest3} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + + # Parse expression list: expr1, expr2, ... 
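+  # For example (illustrative): for the tokens of "a, b end" this returns
+  # {:ok, [%Expr.Var{name: "a"}, %Expr.Var{name: "b"}], rest} with the 'end'
+  # keyword still at the head of rest, leaving block termination to the caller.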
+ defp parse_expr_list(tokens) do + parse_expr_list_acc(tokens, []) + end + + defp parse_expr_list_acc(tokens, acc) do + case parse_expr(tokens) do + {:ok, expr, rest} -> + case peek(rest) do + {:delimiter, :comma, _} -> + {_, rest2} = consume(rest) + parse_expr_list_acc(rest2, [expr | acc]) + + _ -> + {:ok, Enum.reverse([expr | acc]), rest} + end + + {:error, reason} -> + if acc == [] do + {:error, reason} + else + {:ok, Enum.reverse(acc), tokens} + end + end + end + + defp parse_expr_list_until(tokens, terminator) do + case peek(tokens) do + {:delimiter, ^terminator, _} -> + {_, rest} = consume(tokens) + {:ok, [], rest} + + _ -> + case parse_expr_list(tokens) do + {:ok, exprs, rest} -> + case expect(rest, :delimiter, terminator) do + {:ok, _, rest2} -> + {:ok, exprs, rest2} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + end + + # Token manipulation helpers + + defp peek([token | _]), do: token + defp peek([]), do: nil + + defp consume([token | rest]), do: {token, rest} + defp consume([]), do: {nil, []} + + # Expect a specific token type + defp expect(tokens, expected_type) do + case peek(tokens) do + {^expected_type, _, _} = token -> + {_, rest} = consume(tokens) + {:ok, token, rest} + + {type, _, pos} when type != nil -> + {:error, + {:unexpected_token, type, pos, "Expected #{inspect(expected_type)}, got #{inspect(type)}"}} + + {type, pos} when is_map(pos) -> + # Token without value (like :eof) + {:error, + {:unexpected_token, type, pos, "Expected #{inspect(expected_type)}, got #{inspect(type)}"}} + + nil -> + {:error, {:unexpected_end, "Expected #{inspect(expected_type)}"}} + end + end + + # Expect a specific token type and value + defp expect(tokens, expected_type, expected_value) do + case peek(tokens) do + {^expected_type, ^expected_value, _} = token -> + {_, rest} = consume(tokens) + {:ok, token, rest} + + {type, value, pos} when type != nil and value != nil -> + {:error, + {:unexpected_token, type, pos, + "Expected #{inspect(expected_type)}:#{inspect(expected_value)}, got #{inspect(type)}:#{inspect(value)}"}} + + {type, pos} when is_map(pos) -> + # Token without value (like :eof) + {:error, + {:unexpected_token, type, pos, + "Expected #{inspect(expected_type)}:#{inspect(expected_value)}, got #{inspect(type)}"}} + + nil -> + {:error, {:unexpected_end, "Expected #{inspect(expected_type)}:#{inspect(expected_value)}"}} + end + end + + # Error conversion helpers + + defp convert_error({:unexpected_token, type, pos, message}, _code) do + Error.new(:unexpected_token, message, pos, + suggestion: suggest_for_token_error(type, message) + ) + end + + defp convert_error({:unexpected_end, message}, _code) do + Error.new(:unexpected_end, message, nil, + suggestion: """ + The parser reached the end of the file unexpectedly. + Check for missing closing delimiters or keywords like 'end', ')', '}', or ']'. + """ + ) + end + + defp convert_error({:lexer_error, reason}, code) do + convert_lexer_error(reason, code) + end + + defp convert_error({:not_implemented, feature}, _code) do + Error.new(:invalid_syntax, "Feature not yet implemented: #{feature}", nil) + end + + defp convert_error(other, _code) do + Error.new(:invalid_syntax, "Parse error: #{inspect(other)}", nil) + end + + defp convert_lexer_error({:unexpected_character, char, pos}, _code) do + Error.new(:invalid_syntax, "Unexpected character: #{<>}", pos, + suggestion: """ + This character is not valid in Lua syntax. + Check for typos or invisible characters. 
+ """ + ) + end + + defp convert_lexer_error({:unclosed_string, pos}, _code) do + Error.new(:unclosed_delimiter, "Unclosed string literal", pos, + suggestion: """ + Add a closing quote (" or ') to finish the string. + Strings cannot span multiple lines unless you use [[...]] syntax. + """ + ) + end + + defp convert_lexer_error({:unclosed_long_string, pos}, _code) do + Error.new(:unclosed_delimiter, "Unclosed long string [[...]]", pos, + suggestion: "Add the closing ]] to finish the long string." + ) + end + + defp convert_lexer_error({:unclosed_comment, pos}, _code) do + Error.new(:unclosed_delimiter, "Unclosed multi-line comment --[[...]]", pos, + suggestion: "Add the closing ]] to finish the comment." + ) + end + + defp convert_lexer_error(other, _code) do + Error.new(:lexer_error, "Lexer error: #{inspect(other)}", nil) + end + + defp suggest_for_token_error(type, message) do + cond do + type == :eof -> + "Reached end of file unexpectedly. Check for missing 'end' keywords or closing delimiters." + + String.contains?(message, "Expected 'end'") -> + """ + Every block needs an 'end': + - if/elseif/else ... end + - while ... do ... end + - for ... do ... end + - function ... end + - do ... end + """ + + String.contains?(message, "Expected 'then'") -> + "In Lua, 'if' and 'elseif' conditions must be followed by 'then'." + + String.contains?(message, "Expected 'do'") -> + "In Lua, 'while' and 'for' loops must have 'do' before the body." + + true -> + nil + end + end +end diff --git a/lib/lua/parser/error.ex b/lib/lua/parser/error.ex new file mode 100644 index 0000000..ec9345c --- /dev/null +++ b/lib/lua/parser/error.ex @@ -0,0 +1,344 @@ +defmodule Lua.Parser.Error do + @moduledoc """ + Beautiful error reporting for the Lua parser. + + Provides detailed error messages with: + - Source code context with line numbers + - Visual indicators pointing to the error location + - Helpful suggestions for common mistakes + - Multiple error reporting + """ + + alias Lua.AST.Meta + + @type position :: Meta.position() + + @type t :: %__MODULE__{ + type: error_type(), + message: String.t(), + position: position() | nil, + suggestion: String.t() | nil, + source_lines: [String.t()], + related: [t()] + } + + @type error_type :: + :unexpected_token + | :unexpected_end + | :expected_token + | :unclosed_delimiter + | :invalid_syntax + | :lexer_error + | :multiple_errors + + defstruct [ + :type, + :message, + :position, + :suggestion, + source_lines: [], + related: [] + ] + + @doc """ + Creates a new error. + """ + @spec new(error_type(), String.t(), position() | nil, keyword()) :: t() + def new(type, message, position \\ nil, opts \\ []) do + %__MODULE__{ + type: type, + message: message, + position: position, + suggestion: opts[:suggestion], + source_lines: opts[:source_lines] || [], + related: opts[:related] || [] + } + end + + @doc """ + Creates an error for unexpected token. + """ + @spec unexpected_token(atom(), term(), position(), String.t()) :: t() + def unexpected_token(token_type, token_value, position, context) do + message = """ + Unexpected #{format_token(token_type, token_value)} in #{context} + """ + + suggestion = suggest_for_unexpected_token(token_type, token_value, context) + + new(:unexpected_token, message, position, suggestion: suggestion) + end + + @doc """ + Creates an error for expected token. 
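+
+  For example (illustrative; `pos` stands for a position map from the lexer):
+
+      expected_token(:keyword, :end, :eof, nil, pos)
+      # => %Lua.Parser.Error{type: :expected_token,
+      #      message: "Expected 'end', but got end of input\n", ...}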
+ """ + @spec expected_token(atom(), term() | nil, atom(), term(), position()) :: t() + def expected_token(expected_type, expected_value, got_type, got_value, position) do + expected = format_token(expected_type, expected_value) + got = format_token(got_type, got_value) + + message = """ + Expected #{expected}, but got #{got} + """ + + suggestion = suggest_for_expected_token(expected_type, expected_value, got_type) + + new(:expected_token, message, position, suggestion: suggestion) + end + + @doc """ + Creates an error for unclosed delimiter. + """ + @spec unclosed_delimiter(atom(), position(), position() | nil) :: t() + def unclosed_delimiter(delimiter, open_pos, close_pos \\ nil) do + delimiter_str = format_delimiter(delimiter) + + message = """ + Unclosed #{delimiter_str} + """ + + suggestion = """ + Add a closing #{closing_delimiter(delimiter)} to match the opening at line #{open_pos.line} + """ + + new(:unclosed_delimiter, message, close_pos || open_pos, suggestion: suggestion) + end + + @doc """ + Creates an error for unexpected end of input. + """ + @spec unexpected_end(String.t(), position() | nil) :: t() + def unexpected_end(context, position \\ nil) do + message = """ + Unexpected end of input while parsing #{context} + """ + + suggestion = """ + Check for missing closing delimiters or keywords like 'end', ')', '}', or ']' + """ + + new(:unexpected_end, message, position, suggestion: suggestion) + end + + @doc """ + Formats an error into a beautiful multi-line string with context. + """ + @spec format(t(), String.t()) :: String.t() + def format(error, source_code) do + lines = String.split(source_code, "\n") + + header = [ + IO.ANSI.red() <> IO.ANSI.bright() <> "Parse Error" <> IO.ANSI.reset(), + "" + ] + + location = + if error.position do + pos = error.position + " at line #{pos.line}, column #{pos.column}:" + else + " (no position information)" + end + + message_lines = [ + location, + "", + indent(error.message, 2) + ] + + context_lines = + if error.position && length(lines) > 0 do + format_context(lines, error.position) + else + [] + end + + suggestion_lines = + if error.suggestion do + [ + "", + IO.ANSI.cyan() <> "Suggestion:" <> IO.ANSI.reset(), + indent(error.suggestion, 2) + ] + else + [] + end + + related_lines = + if length(error.related) > 0 do + [ + "", + IO.ANSI.yellow() <> "Related errors:" <> IO.ANSI.reset() + ] ++ Enum.flat_map(error.related, fn rel -> ["", indent(format(rel, source_code), 2)] end) + else + [] + end + + (header ++ message_lines ++ context_lines ++ suggestion_lines ++ related_lines) + |> Enum.join("\n") + end + + @doc """ + Formats multiple errors together. 
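+
+  For example (illustrative; `errors` and `source` are placeholders):
+
+      format_multiple(errors, source)
+      # => a "Found 2 parse errors" header, followed by an "Error 1:" and an
+      #    "Error 2:" block, each rendered with format/2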
+ """ + @spec format_multiple([t()], String.t()) :: String.t() + def format_multiple(errors, source_code) do + header = [ + IO.ANSI.red() <> IO.ANSI.bright() <> + "Found #{length(errors)} parse error#{if length(errors) == 1, do: "", else: "s"}" <> + IO.ANSI.reset(), + "" + ] + + error_lines = + errors + |> Enum.with_index(1) + |> Enum.flat_map(fn {error, idx} -> + [ + IO.ANSI.yellow() <> "Error #{idx}:" <> IO.ANSI.reset(), + format(error, source_code), + "" + ] + end) + + (header ++ error_lines) + |> Enum.join("\n") + end + + # Private helpers + + defp format_context(lines, position) do + line_num = position.line + column = position.column + + # Show 2 lines before and after + start_line = max(1, line_num - 2) + end_line = min(length(lines), line_num + 2) + + context_lines = + Enum.slice(lines, (start_line - 1)..(end_line - 1)) + |> Enum.with_index(start_line) + |> Enum.flat_map(fn {line, num} -> + line_str = format_line_number(num) <> " │ " <> line + + if num == line_num do + # Error line + pointer = String.duplicate(" ", String.length(format_line_number(num)) + 3 + column - 1) + pointer = pointer <> IO.ANSI.red() <> "^" <> IO.ANSI.reset() + + [ + IO.ANSI.red() <> line_str <> IO.ANSI.reset(), + pointer + ] + else + # Context line + [IO.ANSI.faint() <> line_str <> IO.ANSI.reset()] + end + end) + + ["", ""] ++ context_lines + end + + defp format_line_number(num) do + num + |> Integer.to_string() + |> String.pad_leading(4) + end + + defp format_token(type, value) do + case type do + :keyword -> "'#{value}'" + :identifier -> "identifier '#{value}'" + :number -> "number #{value}" + :string -> "string \"#{value}\"" + :operator -> "operator '#{value}'" + :delimiter -> "'#{value}'" + :eof -> "end of input" + _ -> "#{type}" + end + end + + defp format_delimiter(delimiter) do + case delimiter do + :lparen -> "opening parenthesis '('" + :lbracket -> "opening bracket '['" + :lbrace -> "opening brace '{'" + :function -> "'function' block" + :if -> "'if' statement" + :while -> "'while' loop" + :for -> "'for' loop" + :do -> "'do' block" + _ -> "#{delimiter}" + end + end + + defp closing_delimiter(delimiter) do + case delimiter do + :lparen -> "')'" + :lbracket -> "']'" + :lbrace -> "'}'" + :function -> "'end'" + :if -> "'end'" + :while -> "'end'" + :for -> "'end'" + :do -> "'end'" + _ -> "matching delimiter" + end + end + + defp suggest_for_unexpected_token(token_type, _token_value, context) do + cond do + token_type == :delimiter -> + "Check for missing operators or keywords before this delimiter" + + String.contains?(context, "expression") -> + "Expected an expression here (variable, number, string, table, function, etc.)" + + String.contains?(context, "statement") -> + "Expected a statement here (assignment, function call, if, while, for, etc.)" + + true -> + nil + end + end + + defp suggest_for_expected_token(expected_type, expected_value, got_type) do + cond do + expected_type == :keyword && expected_value == :end -> + "Add 'end' to close the block. Check that all opening keywords (if, while, for, function, do) have matching 'end' keywords." + + expected_type == :keyword && expected_value == :then -> + "Add 'then' after the condition. Lua requires 'then' after if/elseif conditions." + + expected_type == :keyword && expected_value == :do -> + "Add 'do' to start the loop body. Lua requires 'do' after while/for conditions." + + expected_type == :delimiter && expected_value == :rparen -> + "Add ')' to close the parentheses. Check for balanced parentheses." 
+
+      expected_type == :delimiter && expected_value == :rbracket ->
+        "Add ']' to close the brackets. Check for balanced brackets."
+
+      expected_type == :delimiter && expected_value == :rbrace ->
+        "Add '}' to close the table constructor. Check for balanced braces."
+
+      expected_type == :operator && expected_value == :assign ->
+        "Add '=' for assignment. Did you mean to assign a value?"
+
+      expected_type == :identifier && got_type == :keyword ->
+        "Cannot use Lua keyword as identifier. Choose a different name."
+
+      true ->
+        nil
+    end
+  end
+
+  defp indent(text, spaces) do
+    prefix = String.duplicate(" ", spaces)
+
+    text
+    |> String.split("\n")
+    |> Enum.map(&(prefix <> &1))
+    |> Enum.join("\n")
+  end
+end
diff --git a/lib/lua/parser/pratt.ex b/lib/lua/parser/pratt.ex
new file mode 100644
index 0000000..9ce6d4c
--- /dev/null
+++ b/lib/lua/parser/pratt.ex
@@ -0,0 +1,130 @@
+defmodule Lua.Parser.Pratt do
+  @moduledoc """
+  Pratt parser for Lua expressions.
+
+  Implements operator precedence parsing using binding powers.
+  Handles the precedence levels listed below; the Lua 5.3 bitwise operators
+  (& | ~ << >>) are not produced by the lexer and are therefore not covered.
+
+  Precedence (lowest to highest):
+  1. or
+  2. and
+  3. < > <= >= ~= ==
+  4. ..
+  5. + -
+  6. * / // %
+  7. unary (not # -)
+  8. ^
+  """
+
+  alias Lua.AST.Expr
+
+  @doc """
+  Returns the binding power (precedence) for binary operators.
+
+  Returns {left_bp, right_bp} where:
+  - left_bp: how strongly the operator binds to its left; the parser only
+    consumes the operator when left_bp >= the current minimum precedence
+  - right_bp: the minimum precedence used when parsing the right operand
+
+  Left associative operators have left_bp < right_bp (e.g. :add is {9, 10}),
+  so a following operator of the same precedence is not absorbed by the
+  recursive call. Right associative operators have left_bp > right_bp
+  (e.g. :concat is {7, 6} and :pow is {16, 15}).
+  """
+  @spec binding_power(atom()) :: {non_neg_integer(), non_neg_integer()} | nil
+  def binding_power(:or), do: {1, 2}
+  def binding_power(:and), do: {3, 4}
+
+  # Comparison operators (left associative)
+  def binding_power(:lt), do: {5, 6}
+  def binding_power(:gt), do: {5, 6}
+  def binding_power(:le), do: {5, 6}
+  def binding_power(:ge), do: {5, 6}
+  def binding_power(:ne), do: {5, 6}
+  def binding_power(:eq), do: {5, 6}
+
+  # String concatenation (right associative)
+  def binding_power(:concat), do: {7, 6}
+
+  # Additive (left associative)
+  def binding_power(:add), do: {9, 10}
+  def binding_power(:sub), do: {9, 10}
+
+  # Multiplicative (left associative)
+  def binding_power(:mul), do: {11, 12}
+  def binding_power(:div), do: {11, 12}
+  def binding_power(:floordiv), do: {11, 12}
+  def binding_power(:mod), do: {11, 12}
+
+  # Unary operators
+  def binding_power(:not), do: {13, 14}
+  def binding_power(:neg), do: {13, 14}
+  def binding_power(:len), do: {13, 14}
+
+  # Power (right associative)
+  def binding_power(:pow), do: {16, 15}
+
+  # Not a binary operator
+  def binding_power(_), do: nil
+
+  @doc """
+  Returns the binding power for unary prefix operators.
+
+  This is the minimum precedence required for the operand.
+
+  Note: In Lua, unary minus has an unusual precedence - it's lower than power (^).
+  So -2^3 = -(2^3), not (-2)^3.
+  To achieve this: unary minus binding power (13) < power left_bp (16),
+  allowing power to bind within the unary's operand.
+  But 13 > multiplication left_bp (11), so -a*b = (-a)*b.
+  """
+  @spec prefix_binding_power(atom()) :: non_neg_integer() | nil
+  def prefix_binding_power(:not), do: 14
+  # Between mult (11) and power (16)
+  def prefix_binding_power(:sub), do: 13
+  def prefix_binding_power(:len), do: 14
+  def prefix_binding_power(_), do: nil
+
+  @doc """
+  Maps token operators to AST binary operators.
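+
+  For example:
+
+      token_to_binop(:floordiv)  #=> :floordiv
+      token_to_binop(:assign)    #=> nil (not a binary operator)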
+ """ + @spec token_to_binop(atom()) :: Expr.BinOp.op() | nil + def token_to_binop(:or), do: :or + def token_to_binop(:and), do: :and + def token_to_binop(:lt), do: :lt + def token_to_binop(:gt), do: :gt + def token_to_binop(:le), do: :le + def token_to_binop(:ge), do: :ge + def token_to_binop(:ne), do: :ne + def token_to_binop(:eq), do: :eq + def token_to_binop(:concat), do: :concat + def token_to_binop(:add), do: :add + def token_to_binop(:sub), do: :sub + def token_to_binop(:mul), do: :mul + def token_to_binop(:div), do: :div + def token_to_binop(:floordiv), do: :floordiv + def token_to_binop(:mod), do: :mod + def token_to_binop(:pow), do: :pow + def token_to_binop(_), do: nil + + @doc """ + Maps token operators to AST unary operators. + """ + @spec token_to_unop(atom()) :: Expr.UnOp.op() | nil + def token_to_unop(:not), do: :not + def token_to_unop(:sub), do: :neg + def token_to_unop(:len), do: :len + def token_to_unop(_), do: nil + + @doc """ + Checks if a token is a binary operator. + """ + @spec is_binary_op?(atom()) :: boolean() + def is_binary_op?(op) do + binding_power(op) != nil + end + + @doc """ + Checks if a token is a prefix unary operator. + """ + @spec is_prefix_op?(atom()) :: boolean() + def is_prefix_op?(op) do + prefix_binding_power(op) != nil + end +end diff --git a/lib/lua/parser/recovery.ex b/lib/lua/parser/recovery.ex new file mode 100644 index 0000000..0ae193b --- /dev/null +++ b/lib/lua/parser/recovery.ex @@ -0,0 +1,208 @@ +defmodule Lua.Parser.Recovery do + @moduledoc """ + Error recovery strategies for the Lua parser. + + Allows the parser to continue after encountering errors, + collecting multiple errors in a single parse pass. + """ + + alias Lua.Parser.Error + alias Lua.Lexer + + @type token :: Lexer.token() + @type recovery_result :: {:recovered, [token()], [Error.t()]} | {:failed, [Error.t()]} + + @doc """ + Attempts to recover from a parse error by finding a synchronization point. + + Synchronization points are tokens where we can safely resume parsing: + - Statement boundaries: `;`, `end`, `else`, `elseif`, `until` + - Block terminators: `}`, `)` + - Start of new statements: keywords like `if`, `while`, `for`, `function`, `local` + """ + @spec recover_at_statement([token()], Error.t()) :: recovery_result() + def recover_at_statement(tokens, error) do + case find_statement_boundary(tokens) do + {:ok, rest} -> + {:recovered, rest, [error]} + + :not_found -> + {:failed, [error]} + end + end + + @doc """ + Recovers from an unclosed delimiter by finding the matching closing delimiter. + """ + @spec recover_unclosed_delimiter([token()], atom(), Error.t()) :: recovery_result() + def recover_unclosed_delimiter(tokens, delimiter_type, error) do + closing = closing_delimiter(delimiter_type) + + case find_closing_delimiter(tokens, closing, 1) do + {:ok, rest} -> + {:recovered, rest, [error]} + + :not_found -> + # If we can't find the closing delimiter, try to recover at statement boundary + recover_at_statement(tokens, error) + end + end + + @doc """ + Attempts to recover from missing keyword error. + """ + @spec recover_missing_keyword([token()], atom(), Error.t()) :: recovery_result() + def recover_missing_keyword(tokens, keyword, error) do + case find_keyword(tokens, keyword) do + {:ok, rest} -> + {:recovered, rest, [error]} + + :not_found -> + recover_at_statement(tokens, error) + end + end + + @doc """ + Skips tokens until we find a valid statement start. 
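+
+  For example (illustrative; `bad_tokens` stands for the lexer output for
+  "= = local x = 1"):
+
+      skip_to_statement(bad_tokens)
+      # => the suffix of the token list starting at {:keyword, :local, _}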
+ """ + @spec skip_to_statement([token()]) :: [token()] + def skip_to_statement(tokens) do + case find_statement_boundary(tokens) do + {:ok, rest} -> rest + :not_found -> [] + end + end + + @doc """ + Checks if a token is a statement boundary (synchronization point). + """ + @spec is_statement_boundary?(token()) :: boolean() + def is_statement_boundary?(token) do + case token do + {:delimiter, :semicolon, _} -> true + {:keyword, kw, _} when kw in [:end, :else, :elseif, :until] -> true + {:keyword, kw, _} when kw in [:if, :while, :for, :function, :local, :do, :repeat] -> true + {:eof, _} -> true + _ -> false + end + end + + defmodule DelimiterStack do + @moduledoc """ + Tracks unclosed delimiters in a stack-based manner. + """ + + defstruct stack: [] + + @type t :: %__MODULE__{stack: [{atom(), Meta.position()}]} + + def new, do: %__MODULE__{} + + def push(stack, delimiter, position) do + %{stack | stack: [{delimiter, position} | stack.stack]} + end + + def pop(stack, closing_delimiter) do + case stack.stack do + [{opening, _pos} | rest] -> + if matches?(opening, closing_delimiter) do + {:ok, %{stack | stack: rest}} + else + {:error, :mismatched, opening} + end + + [] -> + {:error, :empty} + end + end + + def peek(stack) do + case stack.stack do + [{delimiter, position} | _] -> {:ok, delimiter, position} + [] -> :empty + end + end + + def empty?(stack), do: stack.stack == [] + + defp matches?(opening, closing) do + case {opening, closing} do + {:lparen, :rparen} -> true + {:lbracket, :rbracket} -> true + {:lbrace, :rbrace} -> true + {:function, :end} -> true + {:if, :end} -> true + {:while, :end} -> true + {:for, :end} -> true + {:do, :end} -> true + _ -> false + end + end + end + + # Private helpers + + defp find_statement_boundary([token | rest]) do + if is_statement_boundary?(token) do + {:ok, [token | rest]} + else + find_statement_boundary(rest) + end + end + + defp find_statement_boundary([]), do: :not_found + + defp find_closing_delimiter([{:delimiter, delim, _} | rest], target, depth) + when delim == target do + if depth == 1 do + {:ok, rest} + else + find_closing_delimiter(rest, target, depth - 1) + end + end + + defp find_closing_delimiter([{:delimiter, opening, _} | rest], target, depth) + when opening in [:lparen, :lbracket, :lbrace] do + find_closing_delimiter(rest, target, depth + 1) + end + + defp find_closing_delimiter([{:keyword, :end, _} | rest], :end, depth) do + if depth == 1 do + {:ok, rest} + else + find_closing_delimiter(rest, :end, depth - 1) + end + end + + defp find_closing_delimiter([{:keyword, kw, _} | rest], :end, depth) + when kw in [:if, :while, :for, :function, :do] do + find_closing_delimiter(rest, :end, depth + 1) + end + + defp find_closing_delimiter([{:eof, _}], _target, _depth), do: :not_found + defp find_closing_delimiter([], _target, _depth), do: :not_found + + defp find_closing_delimiter([_ | rest], target, depth) do + find_closing_delimiter(rest, target, depth) + end + + defp find_keyword([{:keyword, kw, _} | _rest] = tokens, target) when kw == target do + {:ok, tokens} + end + + defp find_keyword([{:eof, _}], _target), do: :not_found + defp find_keyword([], _target), do: :not_found + + defp find_keyword([_ | rest], target) do + find_keyword(rest, target) + end + + defp closing_delimiter(delimiter) do + case delimiter do + :lparen -> :rparen + :lbracket -> :rbracket + :lbrace -> :rbrace + _ -> :end + end + end +end diff --git a/test/lua/ast/builder_test.exs b/test/lua/ast/builder_test.exs new file mode 100644 index 0000000..c33476f --- /dev/null +++ 
b/test/lua/ast/builder_test.exs @@ -0,0 +1,481 @@ +defmodule Lua.AST.BuilderTest do + use ExUnit.Case, async: true + + import Lua.AST.Builder + alias Lua.AST.{Chunk, Block, Expr, Stmt} + + describe "chunk and block" do + test "creates a chunk" do + ast = chunk([local(["x"], [number(42)])]) + assert %Chunk{block: %Block{stmts: [%Stmt.Local{}]}} = ast + end + + test "creates a block" do + blk = block([local(["x"], [number(42)])]) + assert %Block{stmts: [%Stmt.Local{}]} = blk + end + end + + describe "literals" do + test "creates nil literal" do + assert %Expr.Nil{} = nil_lit() + end + + test "creates boolean literals" do + assert %Expr.Bool{value: true} = bool(true) + assert %Expr.Bool{value: false} = bool(false) + end + + test "creates number literal" do + assert %Expr.Number{value: 42} = number(42) + assert %Expr.Number{value: 3.14} = number(3.14) + end + + test "creates string literal" do + assert %Expr.String{value: "hello"} = string("hello") + end + + test "creates vararg" do + assert %Expr.Vararg{} = vararg() + end + end + + describe "variables and access" do + test "creates variable reference" do + assert %Expr.Var{name: "x"} = var("x") + end + + test "creates property access" do + prop = property(var("io"), "write") + assert %Expr.Property{table: %Expr.Var{name: "io"}, field: "write"} = prop + end + + test "creates index access" do + idx = index(var("t"), number(1)) + assert %Expr.Index{table: %Expr.Var{name: "t"}, key: %Expr.Number{value: 1}} = idx + end + + test "creates chained property access" do + prop = property(property(var("a"), "b"), "c") + assert %Expr.Property{ + table: %Expr.Property{ + table: %Expr.Var{name: "a"}, + field: "b" + }, + field: "c" + } = prop + end + end + + describe "operators" do + test "creates binary operation" do + op = binop(:add, number(2), number(3)) + assert %Expr.BinOp{op: :add, left: %Expr.Number{value: 2}, right: %Expr.Number{value: 3}} = op + end + + test "creates all binary operators" do + ops = [:add, :sub, :mul, :div, :floor_div, :mod, :pow, :concat, :eq, :ne, :lt, :gt, :le, :ge, :and, :or] + + for op <- ops do + assert %Expr.BinOp{op: ^op} = binop(op, number(1), number(2)) + end + end + + test "creates unary operation" do + op = unop(:neg, var("x")) + assert %Expr.UnOp{op: :neg, operand: %Expr.Var{name: "x"}} = op + end + + test "creates all unary operators" do + assert %Expr.UnOp{op: :not} = unop(:not, var("x")) + assert %Expr.UnOp{op: :neg} = unop(:neg, var("x")) + assert %Expr.UnOp{op: :len} = unop(:len, var("x")) + end + + test "creates nested operations" do + # (2 + 3) * 4 + op = binop(:mul, binop(:add, number(2), number(3)), number(4)) + assert %Expr.BinOp{ + op: :mul, + left: %Expr.BinOp{op: :add}, + right: %Expr.Number{value: 4} + } = op + end + end + + describe "table constructors" do + test "creates empty table" do + tbl = table([]) + assert %Expr.Table{fields: []} = tbl + end + + test "creates array-style table" do + tbl = table([ + {:list, number(1)}, + {:list, number(2)}, + {:list, number(3)} + ]) + assert %Expr.Table{fields: [{:list, _}, {:list, _}, {:list, _}]} = tbl + end + + test "creates record-style table" do + tbl = table([ + {:record, string("x"), number(10)}, + {:record, string("y"), number(20)} + ]) + assert %Expr.Table{fields: [{:record, _, _}, {:record, _, _}]} = tbl + end + + test "creates mixed table" do + tbl = table([ + {:list, number(1)}, + {:record, string("x"), number(10)} + ]) + assert %Expr.Table{fields: [{:list, _}, {:record, _, _}]} = tbl + end + end + + describe "function calls" do + test "creates 
function call" do + c = call(var("print"), [string("hello")]) + assert %Expr.Call{ + func: %Expr.Var{name: "print"}, + args: [%Expr.String{value: "hello"}] + } = c + end + + test "creates function call with multiple arguments" do + c = call(var("print"), [number(1), number(2), number(3)]) + assert %Expr.Call{args: [_, _, _]} = c + end + + test "creates method call" do + mc = method_call(var("file"), "read", [string("*a")]) + assert %Expr.MethodCall{ + object: %Expr.Var{name: "file"}, + method: "read", + args: [%Expr.String{value: "*a"}] + } = mc + end + end + + describe "function expressions" do + test "creates simple function" do + fn_expr = function_expr(["x"], [return_stmt([var("x")])]) + assert %Expr.Function{ + params: ["x"], + body: %Block{stmts: [%Stmt.Return{}]} + } = fn_expr + end + + test "creates function with multiple parameters" do + fn_expr = function_expr(["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + assert %Expr.Function{params: ["a", "b"]} = fn_expr + end + + test "creates function with vararg" do + fn_expr = function_expr([], [return_stmt([vararg()])], vararg: true) + assert %Expr.Function{params: [:vararg]} = fn_expr + end + end + + describe "statements" do + test "creates assignment" do + stmt = assign([var("x")], [number(42)]) + assert %Stmt.Assign{ + targets: [%Expr.Var{name: "x"}], + values: [%Expr.Number{value: 42}] + } = stmt + end + + test "creates multiple assignment" do + stmt = assign([var("x"), var("y")], [number(1), number(2)]) + assert %Stmt.Assign{targets: [_, _], values: [_, _]} = stmt + end + + test "creates local declaration" do + stmt = local(["x"], [number(42)]) + assert %Stmt.Local{names: ["x"], values: [%Expr.Number{value: 42}]} = stmt + end + + test "creates local declaration without value" do + stmt = local(["x"], []) + assert %Stmt.Local{names: ["x"], values: []} = stmt + end + + test "creates local function" do + stmt = local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + assert %Stmt.LocalFunc{ + name: "add", + params: ["a", "b"], + body: %Block{} + } = stmt + end + + test "creates function declaration with string name" do + stmt = func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + assert %Stmt.FuncDecl{name: ["add"], params: ["a", "b"]} = stmt + end + + test "creates function declaration with path name" do + stmt = func_decl(["math", "add"], ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + assert %Stmt.FuncDecl{name: ["math", "add"]} = stmt + end + + test "creates call statement" do + stmt = call_stmt(call(var("print"), [string("hello")])) + assert %Stmt.CallStmt{call: %Expr.Call{}} = stmt + end + + test "creates return statement" do + stmt = return_stmt([]) + assert %Stmt.Return{values: []} = stmt + + stmt = return_stmt([number(42)]) + assert %Stmt.Return{values: [%Expr.Number{value: 42}]} = stmt + end + + test "creates break statement" do + stmt = break_stmt() + assert %Stmt.Break{} = stmt + end + + test "creates goto statement" do + stmt = goto_stmt("label") + assert %Stmt.Goto{label: "label"} = stmt + end + + test "creates label" do + stmt = label("label") + assert %Stmt.Label{name: "label"} = stmt + end + end + + describe "control flow" do + test "creates if statement" do + stmt = if_stmt(var("x"), [return_stmt([number(1)])]) + assert %Stmt.If{ + condition: %Expr.Var{name: "x"}, + then_block: %Block{stmts: [%Stmt.Return{}]}, + elseifs: [], + else_block: nil + } = stmt + end + + test "creates if-else statement" do + stmt = if_stmt( + var("x"), + 
[return_stmt([number(1)])], + else: [return_stmt([number(0)])] + ) + assert %Stmt.If{else_block: %Block{}} = stmt + end + + test "creates if-elseif-else statement" do + stmt = if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], + else: [return_stmt([number(0)])] + ) + assert %Stmt.If{ + elseifs: [{_, %Block{}}], + else_block: %Block{} + } = stmt + end + + test "creates while loop" do + stmt = while_stmt(binop(:gt, var("x"), number(0)), [ + assign([var("x")], [binop(:sub, var("x"), number(1))]) + ]) + assert %Stmt.While{ + condition: %Expr.BinOp{op: :gt}, + body: %Block{} + } = stmt + end + + test "creates repeat-until loop" do + stmt = repeat_stmt( + [assign([var("x")], [binop(:sub, var("x"), number(1))])], + binop(:le, var("x"), number(0)) + ) + assert %Stmt.Repeat{ + body: %Block{}, + condition: %Expr.BinOp{op: :le} + } = stmt + end + + test "creates numeric for loop" do + stmt = for_num("i", number(1), number(10), [ + call_stmt(call(var("print"), [var("i")])) + ]) + assert %Stmt.ForNum{ + var: "i", + start: %Expr.Number{value: 1}, + limit: %Expr.Number{value: 10}, + step: nil, + body: %Block{} + } = stmt + end + + test "creates numeric for loop with step" do + stmt = for_num("i", number(1), number(10), [ + call_stmt(call(var("print"), [var("i")])) + ], step: number(2)) + assert %Stmt.ForNum{step: %Expr.Number{value: 2}} = stmt + end + + test "creates generic for loop" do + stmt = for_in( + ["k", "v"], + [call(var("pairs"), [var("t")])], + [call_stmt(call(var("print"), [var("k"), var("v")]))] + ) + assert %Stmt.ForIn{ + vars: ["k", "v"], + iterators: [%Expr.Call{}], + body: %Block{} + } = stmt + end + + test "creates do block" do + stmt = do_block([ + local(["x"], [number(10)]), + call_stmt(call(var("print"), [var("x")])) + ]) + assert %Stmt.Do{body: %Block{stmts: [_, _]}} = stmt + end + end + + describe "complex structures" do + test "builds nested function with closure" do + # function outer(x) return function(y) return x + y end end + ast = chunk([ + func_decl("outer", ["x"], [ + return_stmt([ + function_expr(["y"], [ + return_stmt([binop(:add, var("x"), var("y"))]) + ]) + ]) + ]) + ]) + + assert %Chunk{ + block: %Block{ + stmts: [ + %Stmt.FuncDecl{ + name: ["outer"], + body: %Block{ + stmts: [ + %Stmt.Return{ + values: [%Expr.Function{}] + } + ] + } + } + ] + } + } = ast + end + + test "builds complex if-elseif-else chain" do + ast = chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([string("positive")])], + elseif: [ + {binop(:lt, var("x"), number(0)), [return_stmt([string("negative")])]}, + {binop(:eq, var("x"), number(0)), [return_stmt([string("zero")])]} + ], + else: [return_stmt([string("unknown")])] + ) + ]) + + assert %Chunk{ + block: %Block{ + stmts: [ + %Stmt.If{ + elseifs: [{_, _}, {_, _}], + else_block: %Block{} + } + ] + } + } = ast + end + + test "builds nested loops" do + # for i = 1, 10 do + # for j = 1, 10 do + # print(i * j) + # end + # end + ast = chunk([ + for_num("i", number(1), number(10), [ + for_num("j", number(1), number(10), [ + call_stmt(call(var("print"), [binop(:mul, var("i"), var("j"))])) + ]) + ]) + ]) + + assert %Chunk{ + block: %Block{ + stmts: [ + %Stmt.ForNum{ + body: %Block{ + stmts: [%Stmt.ForNum{}] + } + } + ] + } + } = ast + end + + test "builds table with complex expressions" do + # { + # x = 1 + 2, + # y = func(), + # [key] = value, + # nested = {a = 1, b = 2} + # } + tbl = table([ + {:record, string("x"), binop(:add, 
number(1), number(2))}, + {:record, string("y"), call(var("func"), [])}, + {:record, var("key"), var("value")}, + {:record, string("nested"), table([ + {:record, string("a"), number(1)}, + {:record, string("b"), number(2)} + ])} + ]) + + assert %Expr.Table{ + fields: [ + {:record, %Expr.String{value: "x"}, %Expr.BinOp{}}, + {:record, %Expr.String{value: "y"}, %Expr.Call{}}, + {:record, %Expr.Var{}, %Expr.Var{}}, + {:record, %Expr.String{value: "nested"}, %Expr.Table{}} + ] + } = tbl + end + end + + describe "integration with parser" do + test "builder output can be printed and reparsed" do + # Build an AST using builder + ast = chunk([ + local(["x"], [number(10)]), + local(["y"], [number(20)]), + assign([var("z")], [binop(:add, var("x"), var("y"))]), + call_stmt(call(var("print"), [var("z")])) + ]) + + # Print it + code = Lua.AST.PrettyPrinter.print(ast) + + # Parse it back + {:ok, reparsed} = Lua.Parser.parse(code) + + # Should have same structure (ignoring meta) + assert length(ast.block.stmts) == length(reparsed.block.stmts) + end + end +end diff --git a/test/lua/ast/pretty_printer_test.exs b/test/lua/ast/pretty_printer_test.exs new file mode 100644 index 0000000..c812f81 --- /dev/null +++ b/test/lua/ast/pretty_printer_test.exs @@ -0,0 +1,425 @@ +defmodule Lua.AST.PrettyPrinterTest do + use ExUnit.Case, async: true + + import Lua.AST.Builder + alias Lua.AST.PrettyPrinter + + describe "literals" do + test "prints nil" do + assert PrettyPrinter.print(chunk([return_stmt([nil_lit()])])) == "return nil\n" + end + + test "prints booleans" do + assert PrettyPrinter.print(chunk([return_stmt([bool(true)])])) == "return true\n" + assert PrettyPrinter.print(chunk([return_stmt([bool(false)])])) == "return false\n" + end + + test "prints numbers" do + assert PrettyPrinter.print(chunk([return_stmt([number(42)])])) == "return 42\n" + assert PrettyPrinter.print(chunk([return_stmt([number(3.14)])])) == "return 3.14\n" + assert PrettyPrinter.print(chunk([return_stmt([number(1.0)])])) == "return 1.0\n" + end + + test "prints strings" do + assert PrettyPrinter.print(chunk([return_stmt([string("hello")])])) == "return \"hello\"\n" + end + + test "escapes special characters in strings" do + ast = chunk([return_stmt([string("hello\nworld")])]) + result = PrettyPrinter.print(ast) + assert result == "return \"hello\\nworld\"\n" + end + + test "prints vararg" do + assert PrettyPrinter.print(chunk([return_stmt([vararg()])])) == "return ...\n" + end + end + + describe "variables and access" do + test "prints variable reference" do + assert PrettyPrinter.print(chunk([return_stmt([var("x")])])) == "return x\n" + end + + test "prints property access" do + ast = chunk([return_stmt([property(var("io"), "write")])]) + assert PrettyPrinter.print(ast) == "return io.write\n" + end + + test "prints index access" do + ast = chunk([return_stmt([index(var("t"), number(1))])]) + assert PrettyPrinter.print(ast) == "return t[1]\n" + end + + test "prints chained property access" do + ast = chunk([return_stmt([ + property(property(var("a"), "b"), "c") + ])]) + assert PrettyPrinter.print(ast) == "return a.b.c\n" + end + end + + describe "operators" do + test "prints binary operators" do + ast = chunk([return_stmt([binop(:add, number(2), number(3))])]) + assert PrettyPrinter.print(ast) == "return 2 + 3\n" + end + + test "prints unary operators" do + ast = chunk([return_stmt([unop(:neg, var("x"))])]) + assert PrettyPrinter.print(ast) == "return -x\n" + + ast = chunk([return_stmt([unop(:not, var("flag"))])]) + assert 
PrettyPrinter.print(ast) == "return not flag\n" + + ast = chunk([return_stmt([unop(:len, var("list"))])]) + assert PrettyPrinter.print(ast) == "return #list\n" + end + + test "handles operator precedence with parentheses" do + # 2 + 3 * 4 should print as is (multiplication has higher precedence) + ast = chunk([return_stmt([ + binop(:add, number(2), binop(:mul, number(3), number(4))) + ])]) + assert PrettyPrinter.print(ast) == "return 2 + 3 * 4\n" + + # (2 + 3) * 4 should have parentheses + ast = chunk([return_stmt([ + binop(:mul, binop(:add, number(2), number(3)), number(4)) + ])]) + assert PrettyPrinter.print(ast) == "return (2 + 3) * 4\n" + end + + test "handles right-associative operators" do + # 2 ^ 3 ^ 4 should print as 2 ^ 3 ^ 4 (right-associative) + ast = chunk([return_stmt([ + binop(:pow, number(2), binop(:pow, number(3), number(4))) + ])]) + assert PrettyPrinter.print(ast) == "return 2 ^ 3 ^ 4\n" + + # (2 ^ 3) ^ 4 should have parentheses + ast = chunk([return_stmt([ + binop(:pow, binop(:pow, number(2), number(3)), number(4)) + ])]) + assert PrettyPrinter.print(ast) == "return (2 ^ 3) ^ 4\n" + end + + test "handles unary minus with power" do + # -2^3 should print with parentheses as -(2^3) is not needed because parser handles it + ast = chunk([return_stmt([unop(:neg, binop(:pow, number(2), number(3)))])]) + result = PrettyPrinter.print(ast) + # Either -2^3 or -(2^3) is acceptable + assert result == "return -(2 ^ 3)\n" or result == "return -2 ^ 3\n" + end + end + + describe "table constructors" do + test "prints empty table" do + ast = chunk([return_stmt([table([])])]) + assert PrettyPrinter.print(ast) == "return {}\n" + end + + test "prints array-style table" do + ast = chunk([return_stmt([ + table([ + {:list, number(1)}, + {:list, number(2)}, + {:list, number(3)} + ]) + ])]) + assert PrettyPrinter.print(ast) == "return {1, 2, 3}\n" + end + + test "prints record-style table" do + ast = chunk([return_stmt([ + table([ + {:record, string("x"), number(10)}, + {:record, string("y"), number(20)} + ]) + ])]) + assert PrettyPrinter.print(ast) == "return {x = 10, y = 20}\n" + end + + test "prints mixed table fields" do + ast = chunk([return_stmt([ + table([ + {:list, number(1)}, + {:record, string("x"), number(10)} + ]) + ])]) + assert PrettyPrinter.print(ast) == "return {1, x = 10}\n" + end + end + + describe "function calls" do + test "prints simple function call" do + ast = chunk([call_stmt(call(var("print"), [string("hello")]))]) + assert PrettyPrinter.print(ast) == "print(\"hello\")\n" + end + + test "prints function call with multiple arguments" do + ast = chunk([call_stmt(call(var("print"), [number(1), number(2), number(3)]))]) + assert PrettyPrinter.print(ast) == "print(1, 2, 3)\n" + end + + test "prints method call" do + ast = chunk([call_stmt(method_call(var("file"), "read", [string("*a")]))]) + assert PrettyPrinter.print(ast) == "file:read(\"*a\")\n" + end + end + + describe "function expressions" do + test "prints simple function" do + ast = chunk([return_stmt([ + function_expr(["x"], [return_stmt([var("x")])]) + ])]) + result = PrettyPrinter.print(ast) + assert result =~ "function(x)" + assert result =~ "return x" + assert result =~ "end" + end + + test "prints function with multiple parameters" do + ast = chunk([return_stmt([ + function_expr(["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + ])]) + result = PrettyPrinter.print(ast) + assert result =~ "function(a, b)" + assert result =~ "return a + b" + end + + test "prints function with vararg" do + ast = 
chunk([return_stmt([ + function_expr([], [return_stmt([vararg()])], vararg: true) + ])]) + result = PrettyPrinter.print(ast) + assert result =~ "function(...)" + end + end + + describe "statements" do + test "prints assignment" do + ast = chunk([assign([var("x")], [number(42)])]) + assert PrettyPrinter.print(ast) == "x = 42\n" + end + + test "prints multiple assignment" do + ast = chunk([assign([var("x"), var("y")], [number(1), number(2)])]) + assert PrettyPrinter.print(ast) == "x, y = 1, 2\n" + end + + test "prints local declaration" do + ast = chunk([local(["x"], [number(42)])]) + assert PrettyPrinter.print(ast) == "local x = 42\n" + end + + test "prints local declaration without value" do + ast = chunk([local(["x"], [])]) + assert PrettyPrinter.print(ast) == "local x\n" + end + + test "prints local function" do + ast = chunk([local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + result = PrettyPrinter.print(ast) + assert result =~ "local function add(a, b)" + assert result =~ "return a + b" + assert result =~ "end" + end + + test "prints function declaration" do + ast = chunk([func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + result = PrettyPrinter.print(ast) + assert result =~ "function add(a, b)" + assert result =~ "return a + b" + assert result =~ "end" + end + + test "prints return statement" do + ast = chunk([return_stmt([])]) + assert PrettyPrinter.print(ast) == "return\n" + + ast = chunk([return_stmt([number(42)])]) + assert PrettyPrinter.print(ast) == "return 42\n" + end + + test "prints break statement" do + ast = chunk([break_stmt()]) + assert PrettyPrinter.print(ast) == "break\n" + end + end + + describe "control flow" do + test "prints if statement" do + ast = chunk([ + if_stmt(var("x"), [return_stmt([number(1)])]) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "if x then" + assert result =~ "return 1" + assert result =~ "end" + end + + test "prints if-else statement" do + ast = chunk([ + if_stmt( + var("x"), + [return_stmt([number(1)])], + else: [return_stmt([number(0)])] + ) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "if x then" + assert result =~ "else" + assert result =~ "end" + end + + test "prints if-elseif-else statement" do + ast = chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], + else: [return_stmt([number(0)])] + ) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "if x > 0 then" + assert result =~ "elseif x < 0 then" + assert result =~ "else" + assert result =~ "end" + end + + test "prints while loop" do + ast = chunk([ + while_stmt(binop(:gt, var("x"), number(0)), [ + assign([var("x")], [binop(:sub, var("x"), number(1))]) + ]) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "while x > 0 do" + assert result =~ "x = x - 1" + assert result =~ "end" + end + + test "prints repeat-until loop" do + ast = chunk([ + repeat_stmt( + [assign([var("x")], [binop(:sub, var("x"), number(1))])], + binop(:le, var("x"), number(0)) + ) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "repeat" + assert result =~ "x = x - 1" + assert result =~ "until x <= 0" + end + + test "prints numeric for loop" do + ast = chunk([ + for_num("i", number(1), number(10), [ + call_stmt(call(var("print"), [var("i")])) + ]) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "for i = 1, 10 do" + assert result =~ "print(i)" + assert result =~ "end" + 
end + + test "prints numeric for loop with step" do + ast = chunk([ + for_num("i", number(1), number(10), [ + call_stmt(call(var("print"), [var("i")])) + ], step: number(2)) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "for i = 1, 10, 2 do" + end + + test "prints generic for loop" do + ast = chunk([ + for_in( + ["k", "v"], + [call(var("pairs"), [var("t")])], + [call_stmt(call(var("print"), [var("k"), var("v")]))] + ) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "for k, v in pairs(t) do" + assert result =~ "print(k, v)" + assert result =~ "end" + end + + test "prints do block" do + ast = chunk([ + do_block([ + local(["x"], [number(10)]), + call_stmt(call(var("print"), [var("x")])) + ]) + ]) + result = PrettyPrinter.print(ast) + assert result =~ "do" + assert result =~ "local x = 10" + assert result =~ "print(x)" + assert result =~ "end" + end + end + + describe "indentation" do + test "indents nested blocks" do + ast = chunk([ + if_stmt(var("x"), [ + if_stmt(var("y"), [ + return_stmt([number(1)]) + ]) + ]) + ]) + result = PrettyPrinter.print(ast) + # Check that nested blocks are indented + lines = String.split(result, "\n", trim: true) + assert Enum.any?(lines, fn line -> String.starts_with?(line, " ") end) + end + + test "respects custom indent size" do + ast = chunk([ + if_stmt(var("x"), [ + return_stmt([number(1)]) + ]) + ]) + result = PrettyPrinter.print(ast, indent: 4) + assert result =~ " return 1" + end + end + + describe "round-trip" do + test "can round-trip simple expressions" do + original = "return 2 + 3\n" + {:ok, ast} = Lua.Parser.parse(original) + printed = PrettyPrinter.print(ast) + assert printed == original + end + + test "can round-trip local assignments" do + original = "local x = 42\n" + {:ok, ast} = Lua.Parser.parse(original) + printed = PrettyPrinter.print(ast) + assert printed == original + end + + test "can round-trip function declarations" do + code = """ + function add(a, b) + return a + b + end + """ + + {:ok, ast} = Lua.Parser.parse(code) + printed = PrettyPrinter.print(ast) + + # Parse again to verify structure matches + {:ok, ast2} = Lua.Parser.parse(printed) + + # Compare AST structures (ignoring meta) + assert ast.block.stmts |> length() == ast2.block.stmts |> length() + end + end +end diff --git a/test/lua/ast/walker_test.exs b/test/lua/ast/walker_test.exs new file mode 100644 index 0000000..e8b5429 --- /dev/null +++ b/test/lua/ast/walker_test.exs @@ -0,0 +1,293 @@ +defmodule Lua.AST.WalkerTest do + use ExUnit.Case, async: true + + import Lua.AST.Builder + alias Lua.AST.{Walker, Expr, Stmt} + + describe "walk/2" do + test "visits all nodes in pre-order" do + # Build: local x = 2 + 3 + ast = chunk([ + local(["x"], [binop(:add, number(2), number(3))]) + ]) + + visited = [] + ref = :erlang.make_ref() + + Walker.walk(ast, fn node -> + send(self(), {ref, node}) + end) + + # Collect all visited nodes + visited = collect_messages(ref, []) + + # Should visit in pre-order: Chunk, Block, Local, BinOp, Number(2), Number(3) + assert length(visited) == 6 + assert hd(visited).__struct__ == Lua.AST.Chunk + end + + test "visits all nodes in post-order" do + # Build: local x = 2 + 3 + ast = chunk([ + local(["x"], [binop(:add, number(2), number(3))]) + ]) + + visited = [] + ref = :erlang.make_ref() + + Walker.walk(ast, fn node -> + send(self(), {ref, node}) + end, order: :post) + + visited = collect_messages(ref, []) + + # Should visit in post-order: children before parents + # Last visited should be Chunk + assert length(visited) == 6 + assert 
List.last(visited).__struct__ == Lua.AST.Chunk + end + + test "walks through if statement with all branches" do + # if x > 0 then return 1 elseif x < 0 then return -1 else return 0 end + ast = chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], + else: [return_stmt([number(0)])] + ) + ]) + + count = count_nodes(ast) + # Chunk + Block + If + 3 conditions + 3 blocks + 3 return stmts + 3 values = many nodes + assert count > 10 + end + + test "walks through function expressions" do + # local f = function(a, b) return a + b end + ast = chunk([ + local(["f"], [function_expr(["a", "b"], [ + return_stmt([binop(:add, var("a"), var("b"))]) + ])]) + ]) + + # Count variable references + var_count = Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + assert var_count == 2 # a and b + end + end + + describe "map/2" do + test "transforms number literals" do + # local x = 2 + 3 + ast = chunk([ + local(["x"], [binop(:add, number(2), number(3))]) + ]) + + # Double all numbers + transformed = Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + # Extract the numbers + numbers = Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [4, 6] # 2*2=4, 3*2=6 + end + + test "transforms variable names" do + # x = y + z + ast = chunk([ + assign([var("x")], [binop(:add, var("y"), var("z"))]) + ]) + + # Add prefix to all variables + transformed = Walker.map(ast, fn + %Expr.Var{name: name} = node -> %{node | name: "local_" <> name} + node -> node + end) + + # Collect variable names + names = Walker.reduce(transformed, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(names) == ["local_x", "local_y", "local_z"] + end + + test "preserves structure while transforming" do + # if true then print(1) end + ast = chunk([ + if_stmt(bool(true), [ + call_stmt(call(var("print"), [number(1)])) + ]) + ]) + + # Transform should preserve structure + transformed = Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 1} + node -> node + end) + + # Extract the if statement + [%Stmt.If{condition: %Expr.Bool{value: true}}] = transformed.block.stmts + + # Number should be transformed + numbers = Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [2] # 1 + 1 = 2 + end + end + + describe "reduce/3" do + test "counts all nodes" do + # local x = 1; local y = 2; return x + y + ast = chunk([ + local(["x"], [number(1)]), + local(["y"], [number(2)]), + return_stmt([binop(:add, var("x"), var("y"))]) + ]) + + count = Walker.reduce(ast, 0, fn _, acc -> acc + 1 end) + + # Should count all nodes + assert count > 5 + end + + test "collects specific node types" do + # local x = 1; y = 2; print(x, y) + ast = chunk([ + local(["x"], [number(1)]), + assign([var("y")], [number(2)]), + call_stmt(call(var("print"), [var("x"), var("y")])) + ]) + + # Collect all variable names + vars = Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(vars) == ["print", "x", "y", "y"] + + # Collect all numbers + numbers = Walker.reduce(ast, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [1, 2] + end + + test "builds maps from nodes" do + # local x = 
10; local y = 20 + ast = chunk([ + local(["x"], [number(10)]), + local(["y"], [number(20)]) + ]) + + # Build map of local declarations: name -> value + locals = Walker.reduce(ast, %{}, fn + %Stmt.Local{names: [name], values: [%Expr.Number{value: n}]}, acc -> + Map.put(acc, name, n) + _, acc -> + acc + end) + + assert locals == %{"x" => 10, "y" => 20} + end + + test "accumulates deeply nested values" do + # function f() return function() return 42 end end + ast = chunk([ + func_decl("f", [], [ + return_stmt([function_expr([], [ + return_stmt([number(42)]) + ])]) + ]) + ]) + + # Count function expressions + func_count = Walker.reduce(ast, 0, fn + %Expr.Function{}, acc -> acc + 1 + _, acc -> acc + end) + + assert func_count == 1 + end + end + + describe "complex AST traversal" do + test "handles nested loops and conditions" do + # for i = 1, 10 do + # if i % 2 == 0 then + # print(i) + # end + # end + ast = chunk([ + for_num("i", number(1), number(10), [ + if_stmt( + binop(:eq, binop(:mod, var("i"), number(2)), number(0)), + [call_stmt(call(var("print"), [var("i")]))] + ) + ]) + ]) + + # Count all operators + ops = Walker.reduce(ast, [], fn + %Expr.BinOp{op: op}, acc -> [op | acc] + _, acc -> acc + end) + + assert :eq in ops + assert :mod in ops + end + + test "handles table constructors" do + # local t = {x = 1, y = 2, [3] = "three"} + ast = chunk([ + local(["t"], [ + table([ + {:record, string("x"), number(1)}, + {:record, string("y"), number(2)}, + {:record, number(3), string("three")} + ]) + ]) + ]) + + # Count table fields + field_count = Walker.reduce(ast, 0, fn + %Expr.Table{fields: fields}, acc -> acc + length(fields) + _, acc -> acc + end) + + assert field_count == 3 + end + end + + # Helper to count nodes + defp count_nodes(ast) do + Walker.reduce(ast, 0, fn _, acc -> acc + 1 end) + end + + # Helper to collect messages + defp collect_messages(ref, acc) do + receive do + {^ref, node} -> collect_messages(ref, [node | acc]) + after + 0 -> Enum.reverse(acc) + end + end +end diff --git a/test/lua/lexer_test.exs b/test/lua/lexer_test.exs new file mode 100644 index 0000000..e59c107 --- /dev/null +++ b/test/lua/lexer_test.exs @@ -0,0 +1,482 @@ +defmodule Lua.LexerTest do + use ExUnit.Case, async: true + alias Lua.Lexer + + doctest Lua.Lexer + + describe "keywords" do + test "tokenizes all Lua keywords" do + keywords = [ + :and, + :break, + :do, + :else, + :elseif, + :end, + :false, + :for, + :function, + :goto, + :if, + :in, + :local, + :nil, + :not, + :or, + :repeat, + :return, + :then, + :true, + :until, + :while + ] + + for keyword <- keywords do + keyword_str = Atom.to_string(keyword) + assert {:ok, tokens} = Lexer.tokenize(keyword_str) + assert [{:keyword, ^keyword, _}, {:eof, _}] = tokens + end + end + + test "keywords are case-sensitive" do + assert {:ok, [{:identifier, "IF", _}, {:eof, _}]} = Lexer.tokenize("IF") + assert {:ok, [{:identifier, "End", _}, {:eof, _}]} = Lexer.tokenize("End") + end + end + + describe "identifiers" do + test "tokenizes simple identifiers" do + assert {:ok, [{:identifier, "foo", _}, {:eof, _}]} = Lexer.tokenize("foo") + assert {:ok, [{:identifier, "bar123", _}, {:eof, _}]} = Lexer.tokenize("bar123") + assert {:ok, [{:identifier, "_test", _}, {:eof, _}]} = Lexer.tokenize("_test") + assert {:ok, [{:identifier, "CamelCase", _}, {:eof, _}]} = Lexer.tokenize("CamelCase") + end + + test "identifiers can start with underscore" do + assert {:ok, [{:identifier, "_", _}, {:eof, _}]} = Lexer.tokenize("_") + assert {:ok, [{:identifier, "__private", _}, {:eof, _}]} = 
Lexer.tokenize("__private") + end + + test "identifiers can contain numbers but not start with them" do + assert {:ok, [{:identifier, "var1", _}, {:eof, _}]} = Lexer.tokenize("var1") + assert {:ok, [{:identifier, "test123abc", _}, {:eof, _}]} = Lexer.tokenize("test123abc") + end + end + + describe "numbers" do + test "tokenizes integers" do + assert {:ok, [{:number, 0, _}, {:eof, _}]} = Lexer.tokenize("0") + assert {:ok, [{:number, 42, _}, {:eof, _}]} = Lexer.tokenize("42") + assert {:ok, [{:number, 12345, _}, {:eof, _}]} = Lexer.tokenize("12345") + end + + test "tokenizes floating point numbers" do + assert {:ok, [{:number, 3.14, _}, {:eof, _}]} = Lexer.tokenize("3.14") + assert {:ok, [{:number, 0.5, _}, {:eof, _}]} = Lexer.tokenize("0.5") + assert {:ok, [{:number, 10.0, _}, {:eof, _}]} = Lexer.tokenize("10.0") + end + + test "tokenizes hexadecimal numbers" do + assert {:ok, [{:number, 255, _}, {:eof, _}]} = Lexer.tokenize("0xFF") + assert {:ok, [{:number, 255, _}, {:eof, _}]} = Lexer.tokenize("0xff") + assert {:ok, [{:number, 0, _}, {:eof, _}]} = Lexer.tokenize("0x0") + assert {:ok, [{:number, 4095, _}, {:eof, _}]} = Lexer.tokenize("0xfff") + end + + test "tokenizes scientific notation" do + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("1e10") + assert num == 1.0e10 + + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("1.5e-5") + assert num == 1.5e-5 + + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("3E+2") + assert num == 3.0e2 + end + + test "handles trailing dot correctly" do + # "42." should be tokenized as number 42 followed by dot operator + # But in Lua, "42." is actually a valid number (42.0) + # Let's test both interpretations + assert {:ok, tokens} = Lexer.tokenize("42.") + # This might be [{:number, 42.0}, {:eof}] or [{:number, 42}, {:delimiter, :dot}, {:eof}] + # depending on implementation + assert length(tokens) >= 2 + end + end + + describe "strings" do + test "tokenizes double-quoted strings" do + assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize(~s("hello")) + assert {:ok, [{:string, "", _}, {:eof, _}]} = Lexer.tokenize(~s("")) + assert {:ok, [{:string, "hello world", _}, {:eof, _}]} = Lexer.tokenize(~s("hello world")) + end + + test "tokenizes single-quoted strings" do + assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize("'hello'") + assert {:ok, [{:string, "", _}, {:eof, _}]} = Lexer.tokenize("''") + assert {:ok, [{:string, "hello world", _}, {:eof, _}]} = Lexer.tokenize("'hello world'") + end + + test "handles escape sequences in strings" do + assert {:ok, [{:string, "hello\nworld", _}, {:eof, _}]} = Lexer.tokenize(~s("hello\\nworld")) + assert {:ok, [{:string, "tab\there", _}, {:eof, _}]} = Lexer.tokenize(~s("tab\\there")) + + assert {:ok, [{:string, "quote\"here", _}, {:eof, _}]} = + Lexer.tokenize(~s("quote\\"here")) + + assert {:ok, [{:string, "backslash\\here", _}, {:eof, _}]} = + Lexer.tokenize(~s("backslash\\\\here")) + end + + test "tokenizes long strings with [[...]]" do + assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize("[[hello]]") + assert {:ok, [{:string, "", _}, {:eof, _}]} = Lexer.tokenize("[[]]") + + assert {:ok, [{:string, "multi\nline", _}, {:eof, _}]} = + Lexer.tokenize("[[multi\nline]]") + end + + test "tokenizes long strings with equals signs [=[...]=]" do + assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize("[=[hello]=]") + assert {:ok, [{:string, "test", _}, {:eof, _}]} = Lexer.tokenize("[==[test]==]") + assert {:ok, [{:string, "a]b", _}, 
{:eof, _}]} = Lexer.tokenize("[=[a]b]=]") + end + + test "reports error for unclosed string" do + assert {:error, {:unclosed_string, _}} = Lexer.tokenize(~s("hello)) + assert {:error, {:unclosed_string, _}} = Lexer.tokenize("'hello") + end + + test "reports error for unclosed long string" do + assert {:error, {:unclosed_long_string, _}} = Lexer.tokenize("[[hello") + assert {:error, {:unclosed_long_string, _}} = Lexer.tokenize("[=[test") + end + end + + describe "operators" do + test "tokenizes single-character operators" do + assert {:ok, [{:operator, :add, _}, {:eof, _}]} = Lexer.tokenize("+") + assert {:ok, [{:operator, :sub, _}, {:eof, _}]} = Lexer.tokenize("-") + assert {:ok, [{:operator, :mul, _}, {:eof, _}]} = Lexer.tokenize("*") + assert {:ok, [{:operator, :div, _}, {:eof, _}]} = Lexer.tokenize("/") + assert {:ok, [{:operator, :mod, _}, {:eof, _}]} = Lexer.tokenize("%") + assert {:ok, [{:operator, :pow, _}, {:eof, _}]} = Lexer.tokenize("^") + assert {:ok, [{:operator, :len, _}, {:eof, _}]} = Lexer.tokenize("#") + assert {:ok, [{:operator, :lt, _}, {:eof, _}]} = Lexer.tokenize("<") + assert {:ok, [{:operator, :gt, _}, {:eof, _}]} = Lexer.tokenize(">") + assert {:ok, [{:operator, :assign, _}, {:eof, _}]} = Lexer.tokenize("=") + end + + test "tokenizes two-character operators" do + assert {:ok, [{:operator, :eq, _}, {:eof, _}]} = Lexer.tokenize("==") + assert {:ok, [{:operator, :ne, _}, {:eof, _}]} = Lexer.tokenize("~=") + assert {:ok, [{:operator, :le, _}, {:eof, _}]} = Lexer.tokenize("<=") + assert {:ok, [{:operator, :ge, _}, {:eof, _}]} = Lexer.tokenize(">=") + assert {:ok, [{:operator, :concat, _}, {:eof, _}]} = Lexer.tokenize("..") + assert {:ok, [{:operator, :floordiv, _}, {:eof, _}]} = Lexer.tokenize("//") + end + + test "tokenizes three-character operators" do + assert {:ok, [{:operator, :vararg, _}, {:eof, _}]} = Lexer.tokenize("...") + end + + test "distinguishes between . and .." 
do + assert {:ok, [{:delimiter, :dot, _}, {:eof, _}]} = Lexer.tokenize(".") + assert {:ok, [{:operator, :concat, _}, {:eof, _}]} = Lexer.tokenize("..") + assert {:ok, [{:operator, :vararg, _}, {:eof, _}]} = Lexer.tokenize("...") + end + end + + describe "delimiters" do + test "tokenizes parentheses" do + assert {:ok, [{:delimiter, :lparen, _}, {:eof, _}]} = Lexer.tokenize("(") + assert {:ok, [{:delimiter, :rparen, _}, {:eof, _}]} = Lexer.tokenize(")") + end + + test "tokenizes braces" do + assert {:ok, [{:delimiter, :lbrace, _}, {:eof, _}]} = Lexer.tokenize("{") + assert {:ok, [{:delimiter, :rbrace, _}, {:eof, _}]} = Lexer.tokenize("}") + end + + test "tokenizes brackets" do + assert {:ok, [{:delimiter, :lbracket, _}, {:eof, _}]} = Lexer.tokenize("[") + end + + test "tokenizes other delimiters" do + assert {:ok, [{:delimiter, :semicolon, _}, {:eof, _}]} = Lexer.tokenize(";") + assert {:ok, [{:delimiter, :comma, _}, {:eof, _}]} = Lexer.tokenize(",") + assert {:ok, [{:delimiter, :colon, _}, {:eof, _}]} = Lexer.tokenize(":") + assert {:ok, [{:delimiter, :double_colon, _}, {:eof, _}]} = Lexer.tokenize("::") + end + end + + describe "comments" do + test "skips single-line comments" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("-- this is a comment") + + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = + Lexer.tokenize("x -- comment after code") + end + + test "skips multi-line comments" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ this is a\nmulti-line comment ]]") + + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = + Lexer.tokenize("x --[[ comment ]] ") + end + + test "skips multi-line comments with equals signs" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[=[ comment ]=]") + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[==[ comment ]==]") + end + + test "handles content with brackets in multi-line comments" do + # The first ]] closes the comment regardless of internal [[ + # So this closes at the first ]] and leaves " inside ]]" as code + assert {:ok, tokens} = Lexer.tokenize("--[[ comment ]]") + assert [{:eof, _}] = tokens + + # With nesting levels using =, you can include ]] in the comment + assert {:ok, tokens2} = Lexer.tokenize("--[=[ comment with ]] in it ]=]") + assert [{:eof, _}] = tokens2 + end + + test "reports error for unclosed multi-line comment" do + assert {:error, {:unclosed_comment, _}} = Lexer.tokenize("--[[ unclosed comment") + end + end + + describe "whitespace" do + test "skips spaces and tabs" do + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1 2") + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\t2") + end + + test "handles newlines" do + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\n2") + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\r\n2") + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\r2") + end + end + + describe "position tracking" do + test "tracks line and column for single line" do + assert {:ok, tokens} = Lexer.tokenize("local x = 42") + + assert [ + {:keyword, :local, %{line: 1, column: 1, byte_offset: 0}}, + {:identifier, "x", %{line: 1, column: 7, byte_offset: 6}}, + {:operator, :assign, %{line: 1, column: 9, byte_offset: 8}}, + {:number, 42, %{line: 1, column: 11, byte_offset: 10}}, + {:eof, _} + ] = tokens + end + + test "tracks line numbers across multiple lines" do + code = """ + local x + x = 42 + """ + + assert {:ok, tokens} = Lexer.tokenize(code) + + assert [ + 
{:keyword, :local, %{line: 1}}, + {:identifier, "x", %{line: 1}}, + {:identifier, "x", %{line: 2}}, + {:operator, :assign, %{line: 2}}, + {:number, 42, %{line: 2}}, + {:eof, _} + ] = tokens + end + + test "tracks position in strings" do + assert {:ok, tokens} = Lexer.tokenize(~s("hello")) + + assert [{:string, "hello", %{line: 1, column: 1, byte_offset: 0}}, {:eof, _}] = tokens + end + end + + describe "complex expressions" do + test "tokenizes arithmetic expression" do + assert {:ok, tokens} = Lexer.tokenize("2 + 3 * 4") + + assert [ + {:number, 2, _}, + {:operator, :add, _}, + {:number, 3, _}, + {:operator, :mul, _}, + {:number, 4, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes function call" do + assert {:ok, tokens} = Lexer.tokenize("print(42)") + + assert [ + {:identifier, "print", _}, + {:delimiter, :lparen, _}, + {:number, 42, _}, + {:delimiter, :rparen, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes table constructor" do + assert {:ok, tokens} = Lexer.tokenize("{a = 1, b = 2}") + + assert [ + {:delimiter, :lbrace, _}, + {:identifier, "a", _}, + {:operator, :assign, _}, + {:number, 1, _}, + {:delimiter, :comma, _}, + {:identifier, "b", _}, + {:operator, :assign, _}, + {:number, 2, _}, + {:delimiter, :rbrace, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes method call" do + assert {:ok, tokens} = Lexer.tokenize("obj:method()") + + assert [ + {:identifier, "obj", _}, + {:delimiter, :colon, _}, + {:identifier, "method", _}, + {:delimiter, :lparen, _}, + {:delimiter, :rparen, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes vararg" do + assert {:ok, tokens} = Lexer.tokenize("function f(...) return ... end") + + assert [ + {:keyword, :function, _}, + {:identifier, "f", _}, + {:delimiter, :lparen, _}, + {:operator, :vararg, _}, + {:delimiter, :rparen, _}, + {:keyword, :return, _}, + {:operator, :vararg, _}, + {:keyword, :end, _}, + {:eof, _} + ] = tokens + end + end + + describe "edge cases" do + test "empty input" do + assert {:ok, [{:eof, %{line: 1, column: 1, byte_offset: 0}}]} = Lexer.tokenize("") + end + + test "only whitespace" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize(" \n \t ") + end + + test "only comments" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("-- just a comment") + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ just a comment ]]") + end + + test "reports error for unexpected character" do + assert {:error, {:unexpected_character, ?@, _}} = Lexer.tokenize("@") + assert {:error, {:unexpected_character, ?$, _}} = Lexer.tokenize("$") + assert {:error, {:unexpected_character, ?`, _}} = Lexer.tokenize("`") + end + + test "handles consecutive operators" do + assert {:ok, tokens} = Lexer.tokenize("+-*/") + + assert [ + {:operator, :add, _}, + {:operator, :sub, _}, + {:operator, :mul, _}, + {:operator, :div, _}, + {:eof, _} + ] = tokens + end + + test "distinguishes >= from > =" do + assert {:ok, [{:operator, :ge, _}, {:eof, _}]} = Lexer.tokenize(">=") + + assert {:ok, [{:operator, :gt, _}, {:operator, :assign, _}, {:eof, _}]} = + Lexer.tokenize("> =") + end + end + + describe "real Lua code examples" do + test "tokenizes variable assignment" do + code = "local x = 42" + assert {:ok, tokens} = Lexer.tokenize(code) + assert length(tokens) == 5 + end + + test "tokenizes if statement" do + code = "if x > 0 then print(x) end" + assert {:ok, tokens} = Lexer.tokenize(code) + + assert [ + {:keyword, :if, _}, + {:identifier, "x", _}, + {:operator, :gt, _}, + {:number, 0, _}, + {:keyword, :then, _}, + {:identifier, "print", _}, + {:delimiter, 
:lparen, _}, + {:identifier, "x", _}, + {:delimiter, :rparen, _}, + {:keyword, :end, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes function definition" do + code = """ + function add(a, b) + return a + b + end + """ + + assert {:ok, tokens} = Lexer.tokenize(code) + + assert Enum.any?(tokens, fn + {:keyword, :function, _} -> true + _ -> false + end) + end + + test "tokenizes for loop" do + code = "for i = 1, 10 do print(i) end" + assert {:ok, tokens} = Lexer.tokenize(code) + + assert [ + {:keyword, :for, _}, + {:identifier, "i", _}, + {:operator, :assign, _}, + {:number, 1, _}, + {:delimiter, :comma, _}, + {:number, 10, _}, + {:keyword, :do, _}, + {:identifier, "print", _}, + {:delimiter, :lparen, _}, + {:identifier, "i", _}, + {:delimiter, :rparen, _}, + {:keyword, :end, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes table with mixed fields" do + code = "{1, 2, x = 3, [\"key\"] = 4}" + assert {:ok, tokens} = Lexer.tokenize(code) + assert length(tokens) > 10 + end + end +end diff --git a/test/lua/parser/beautiful_errors_test.exs b/test/lua/parser/beautiful_errors_test.exs new file mode 100644 index 0000000..02bbabb --- /dev/null +++ b/test/lua/parser/beautiful_errors_test.exs @@ -0,0 +1,269 @@ +defmodule Lua.Parser.BeautifulErrorsTest do + use ExUnit.Case, async: true + alias Lua.Parser + + @moduletag :beautiful_errors + + describe "beautiful error message demonstrations" do + test "missing 'end' keyword shows context and suggestion" do + code = """ + function factorial(n) + if n <= 1 then + return 1 + else + return n * factorial(n - 1) + -- Missing 'end' here! + """ + + assert {:error, msg} = Parser.parse(code) + + # Check for essential components + assert msg =~ ~r/Parse Error/i + assert msg =~ "line" + assert msg =~ "Expected" + assert msg =~ "end" + + # Check for visual formatting + assert msg =~ "│" # Line separator + assert msg =~ "^" # Error pointer + + # Should have ANSI color codes + assert msg =~ "\e[" + + # Print for manual inspection during test runs + if System.get_env("SHOW_ERRORS") do + IO.puts("\n" <> String.duplicate("=", 70)) + IO.puts("Example 1: Missing 'end' keyword") + IO.puts(String.duplicate("=", 70)) + IO.puts(msg) + IO.puts(String.duplicate("=", 70) <> "\n") + end + end + + test "missing 'then' keyword provides helpful suggestion" do + code = """ + if x > 0 + print(x) + end + """ + + assert {:error, msg} = Parser.parse(code) + + assert msg =~ "Parse Error" + assert msg =~ "Expected" + assert msg =~ ":then" + assert msg =~ "line 2" + + # Should show the problematic line + assert msg =~ "print(x)" + + if System.get_env("SHOW_ERRORS") do + IO.puts("\n" <> String.duplicate("=", 70)) + IO.puts("Example 2: Missing 'then' keyword") + IO.puts(String.duplicate("=", 70)) + IO.puts(msg) + IO.puts(String.duplicate("=", 70) <> "\n") + end + end + + test "unclosed string shows line with error pointer" do + code = """ + local message = "Hello, World! + print(message) + """ + + assert {:error, msg} = Parser.parse(code) + + assert msg =~ "Parse Error" + assert msg =~ "Unclosed string" + assert msg =~ "line 1" + + # Should show suggestion + assert msg =~ "Suggestion" + assert msg =~ "closing quote" + + # Should show the unclosed string line + assert msg =~ ~s(local message = "Hello, World!) 
+ + if System.get_env("SHOW_ERRORS") do + IO.puts("\n" <> String.duplicate("=", 70)) + IO.puts("Example 3: Unclosed string") + IO.puts(String.duplicate("=", 70)) + IO.puts(msg) + IO.puts(String.duplicate("=", 70) <> "\n") + end + end + + test "missing closing parenthesis shows context" do + code = """ + local function test(a, b + return a + b + end + """ + + assert {:error, msg} = Parser.parse(code) + + assert msg =~ "Parse Error" + assert msg =~ "Expected" + assert msg =~ ":rparen" + + if System.get_env("SHOW_ERRORS") do + IO.puts("\n" <> String.duplicate("=", 70)) + IO.puts("Example 4: Missing closing parenthesis") + IO.puts(String.duplicate("=", 70)) + IO.puts(msg) + IO.puts(String.duplicate("=", 70) <> "\n") + end + end + + test "invalid character shows clear message" do + code = """ + local x = 42 + local y = @invalid + """ + + assert {:error, msg} = Parser.parse(code) + + assert msg =~ "Parse Error" + assert msg =~ "Unexpected character" + assert msg =~ "line 2" + assert msg =~ "@" + + # Should have suggestion + assert msg =~ "Suggestion" + + if System.get_env("SHOW_ERRORS") do + IO.puts("\n" <> String.duplicate("=", 70)) + IO.puts("Example 5: Invalid character") + IO.puts(String.duplicate("=", 70)) + IO.puts(msg) + IO.puts(String.duplicate("=", 70) <> "\n") + end + end + + test "missing 'do' in while loop" do + code = """ + while x > 0 + x = x - 1 + end + """ + + assert {:error, msg} = Parser.parse(code) + + assert msg =~ "Parse Error" + assert msg =~ "Expected" + assert msg =~ ":do" + + if System.get_env("SHOW_ERRORS") do + IO.puts("\n" <> String.duplicate("=", 70)) + IO.puts("Example 6: Missing 'do' in while loop") + IO.puts(String.duplicate("=", 70)) + IO.puts(msg) + IO.puts(String.duplicate("=", 70) <> "\n") + end + end + + test "complex error with multiple context lines" do + code = """ + function complex_function() + local x = 10 + local y = 20 + if x > y then + return x + -- Missing 'end' for if + return y + -- Missing 'end' for function + """ + + assert {:error, msg} = Parser.parse(code) + + assert msg =~ "Parse Error" + + # Error is at EOF (line 9), so context shows lines around line 9 + # Should show the lines that are actually in the context window (lines 7-9) + assert msg =~ "return y" + assert msg =~ "-- Missing 'end' for function" + + if System.get_env("SHOW_ERRORS") do + IO.puts("\n" <> String.duplicate("=", 70)) + IO.puts("Example 7: Complex error with context") + IO.puts(String.duplicate("=", 70)) + IO.puts(msg) + IO.puts(String.duplicate("=", 70) <> "\n") + end + end + + test "error message formatting has proper structure" do + code = "if x then" + + assert {:error, msg} = Parser.parse(code) + + # Check structure components + assert msg =~ "Parse Error" + assert msg =~ "at line" + assert msg =~ "column" + + # Check visual elements + assert msg =~ "│" # Box drawing character for line separator + assert msg =~ "^" # Pointer to error location + + # Check color codes (ANSI escape sequences) + assert String.contains?(msg, "\e[31m") # Red color for error + assert String.contains?(msg, "\e[0m") # Reset color + end + end + + describe "error message quality checks" do + test "always includes line and column information when available" do + code = """ + local x = 1 + if x > 0 then + print(x + end + """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ ~r/line \d+/ + assert msg =~ ~r/column \d+/ + end + + test "always includes visual pointer to error location" do + code = "local x = +" + + assert {:error, msg} = Parser.parse(code) + assert msg =~ "^" # Caret 
pointer + end + + test "shows surrounding context lines" do + code = """ + line1 = 1 + line2 = 2 + if x then + line4 = 4 + line5 = 5 + """ + + # This will fail to parse due to missing 'end' + {:error, msg} = Parser.parse(code) + + # Should show lines around the error with box drawing + assert msg =~ "│" # Line separator + assert msg =~ "line" # Should show context lines + end + + test "uses colors for better readability" do + code = "if x then" + + assert {:error, msg} = Parser.parse(code) + + # Red for errors + assert msg =~ "\e[31m" + # Bright/bold + assert msg =~ "\e[1m" + # Reset + assert msg =~ "\e[0m" + # Cyan for suggestions + assert msg =~ "\e[36m" + end + end +end diff --git a/test/lua/parser/error_test.exs b/test/lua/parser/error_test.exs new file mode 100644 index 0000000..15f268d --- /dev/null +++ b/test/lua/parser/error_test.exs @@ -0,0 +1,172 @@ +defmodule Lua.Parser.ErrorTest do + use ExUnit.Case, async: true + alias Lua.Parser + + describe "beautiful error messages" do + test "missing 'end' keyword shows helpful message" do + code = """ + function foo() + return 1 + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Parse Error") + assert String.contains?(error_msg, "line 3") + assert String.contains?(error_msg, "Expected") + assert String.contains?(error_msg, "'end'") + end + + test "missing 'then' keyword provides suggestion" do + code = """ + if x > 0 + return x + end + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Expected") + assert String.contains?(error_msg, ":then") + end + + test "missing 'do' keyword in while loop" do + code = """ + while x > 0 + x = x - 1 + end + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Expected") + assert String.contains?(error_msg, ":do") + end + + test "unclosed string shows context" do + code = """ + local x = "hello + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Unclosed string") + assert String.contains?(error_msg, "line 1") + end + + test "unexpected character shows position" do + code = """ + local x = 42 + local y = @invalid + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Unexpected character") + assert String.contains?(error_msg, "line 2") + end + + test "missing closing parenthesis" do + code = """ + print(1, 2, 3 + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Parse Error") + # Should mention parenthesis or bracket + end + + test "missing closing bracket" do + code = """ + local t = {1, 2, 3 + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Parse Error") + end + + test "shows context with line numbers" do + code = """ + local x = 1 + local y = 2 + if x > y + print(x) + end + """ + + assert {:error, error_msg} = Parser.parse(code) + # Should show context around the error + assert String.contains?(error_msg, "│") + end + + test "unexpected token in expression" do + code = """ + local x = 1 + + 2 + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Parse Error") + end + + test "invalid syntax after valid code" do + code = """ + function add(a, b) + return a + b + end + + function multiply(x, y + return x * y + end + """ + + assert {:error, error_msg} = Parser.parse(code) + # Error is on line 6 (the return statement) because line 5 is missing closing ) + assert 
String.contains?(error_msg, "line 6") + end + end + + describe "error message formatting" do + test "formats with color codes for terminal" do + code = "if x then" + + assert {:error, error_msg} = Parser.parse(code) + # Color codes should be present (ANSI escape codes) + assert String.contains?(error_msg, "\e[") + end + + test "shows helpful suggestions" do + code = """ + function test() + print("hello") + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "Suggestion") + end + + test "includes line and column information" do + code = """ + local x = 1 + if x > 0 then + print(x + end + """ + + assert {:error, error_msg} = Parser.parse(code) + assert String.contains?(error_msg, "line") + assert String.contains?(error_msg, "column") + end + end + + describe "raw error parsing" do + test "parse_raw returns structured error" do + code = "if x then" + + assert {:error, error_tuple} = Parser.parse_raw(code) + # Should be a tuple, not a formatted string + assert is_tuple(error_tuple) + end + + test "parse_raw successful parsing" do + code = "local x = 42" + + assert {:ok, chunk} = Parser.parse_raw(code) + assert chunk.__struct__ == Lua.AST.Chunk + end + end +end diff --git a/test/lua/parser/expr_test.exs b/test/lua/parser/expr_test.exs new file mode 100644 index 0000000..02cec9d --- /dev/null +++ b/test/lua/parser/expr_test.exs @@ -0,0 +1,85 @@ +defmodule Lua.Parser.ExprTest do + use ExUnit.Case, async: true + alias Lua.Parser + alias Lua.AST.{Expr, Stmt} + + # Helper to extract the returned expression from "return expr" + defp parse_return_expr(code) do + case Parser.parse(code) do + {:ok, %{block: %{stmts: [%Stmt.Return{values: [expr]}]}}} -> + {:ok, expr} + + {:ok, %{block: %{stmts: [%Stmt.Return{values: exprs}]}}} -> + {:ok, exprs} + + other -> + other + end + end + + describe "basic parsing" do + test "parses simple expressions" do + assert {:ok, %Expr.Number{value: 42}} = parse_return_expr("return 42") + assert {:ok, %Expr.Bool{value: true}} = parse_return_expr("return true") + assert {:ok, %Expr.String{value: "hello"}} = parse_return_expr(~s(return "hello")) + assert {:ok, %Expr.Nil{}} = parse_return_expr("return nil") + assert {:ok, %Expr.Var{name: "x"}} = parse_return_expr("return x") + end + + test "parses binary operations" do + assert {:ok, %Expr.BinOp{op: :add}} = parse_return_expr("return 1 + 2") + assert {:ok, %Expr.BinOp{op: :mul}} = parse_return_expr("return 2 * 3") + assert {:ok, %Expr.BinOp{op: :concat}} = parse_return_expr(~s(return "a" .. "b")) + end + + test "parses unary operations" do + assert {:ok, %Expr.UnOp{op: :not}} = parse_return_expr("return not true") + assert {:ok, %Expr.UnOp{op: :neg}} = parse_return_expr("return -5") + assert {:ok, %Expr.UnOp{op: :len}} = parse_return_expr("return #t") + end + + test "parses table constructors" do + assert {:ok, %Expr.Table{fields: []}} = parse_return_expr("return {}") + assert {:ok, %Expr.Table{fields: [_, _, _]}} = parse_return_expr("return {1, 2, 3}") + assert {:ok, %Expr.Table{}} = parse_return_expr("return {a = 1, b = 2}") + end + + test "parses function expressions" do + assert {:ok, %Expr.Function{params: []}} = + parse_return_expr("return function() end") + + assert {:ok, %Expr.Function{params: ["a", "b"]}} = + parse_return_expr("return function(a, b) end") + + assert {:ok, %Expr.Function{params: ["a", :vararg]}} = + parse_return_expr("return function(a, ...) 
end") + end + + test "parses function calls" do + assert {:ok, %Expr.Call{func: %Expr.Var{name: "f"}, args: []}} = + parse_return_expr("return f()") + + assert {:ok, %Expr.Call{args: [_, _, _]}} = parse_return_expr("return f(1, 2, 3)") + end + + test "parses property access and indexing" do + assert {:ok, %Expr.Property{table: %Expr.Var{name: "t"}, field: "field"}} = + parse_return_expr("return t.field") + + assert {:ok, %Expr.Index{table: %Expr.Var{name: "t"}}} = + parse_return_expr("return t[1]") + end + + test "parses method calls" do + assert {:ok, %Expr.MethodCall{object: %Expr.Var{name: "obj"}, method: "method"}} = + parse_return_expr("return obj:method()") + end + + test "parses complex nested expressions" do + assert {:ok, _} = parse_return_expr("return 1 + 2 * 3") + assert {:ok, _} = parse_return_expr("return (1 + 2) * 3") + assert {:ok, _} = parse_return_expr("return f(g(h(x)))") + assert {:ok, _} = parse_return_expr("return t.a.b.c") + end + end +end diff --git a/test/lua/parser/precedence_test.exs b/test/lua/parser/precedence_test.exs new file mode 100644 index 0000000..08ddf37 --- /dev/null +++ b/test/lua/parser/precedence_test.exs @@ -0,0 +1,429 @@ +defmodule Lua.Parser.PrecedenceTest do + use ExUnit.Case, async: true + alias Lua.Parser + alias Lua.AST.Expr + + describe "operator precedence" do + test "or has lowest precedence" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a and b or c and d") + + # Should parse as: (a and b) or (c and d) + assert %{ + values: [ + %Expr.BinOp{ + op: :or, + left: %Expr.BinOp{op: :and}, + right: %Expr.BinOp{op: :and} + } + ] + } = stmt + end + + test "and has higher precedence than or" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a or b and c") + + # Should parse as: a or (b and c) + assert %{ + values: [ + %Expr.BinOp{ + op: :or, + left: %Expr.Var{name: "a"}, + right: %Expr.BinOp{op: :and} + } + ] + } = stmt + end + + test "comparison has higher precedence than logical" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a < b and c > d") + + # Should parse as: (a < b) and (c > d) + assert %{ + values: [ + %Expr.BinOp{ + op: :and, + left: %Expr.BinOp{op: :lt}, + right: %Expr.BinOp{op: :gt} + } + ] + } = stmt + end + + test "concatenation has higher precedence than comparison" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse(~s(return "a" .. "b" < "c" .. "d")) + + # Should parse as: ("a" .. "b") < ("c" .. "d") + assert %{ + values: [ + %Expr.BinOp{ + op: :lt, + left: %Expr.BinOp{op: :concat}, + right: %Expr.BinOp{op: :concat} + } + ] + } = stmt + end + + test "addition has higher precedence than concatenation" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a + b .. c + d") + + # Should parse as: (a + b) .. 
(c + d) + assert %{ + values: [ + %Expr.BinOp{ + op: :concat, + left: %Expr.BinOp{op: :add}, + right: %Expr.BinOp{op: :add} + } + ] + } = stmt + end + + test "multiplication has higher precedence than addition" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a + b * c") + + # Should parse as: a + (b * c) + assert %{ + values: [ + %Expr.BinOp{ + op: :add, + left: %Expr.Var{name: "a"}, + right: %Expr.BinOp{op: :mul} + } + ] + } = stmt + end + + test "unary has higher precedence than multiplication" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return -a * b") + + # Should parse as: (-a) * b + assert %{ + values: [ + %Expr.BinOp{ + op: :mul, + left: %Expr.UnOp{op: :neg}, + right: %Expr.Var{name: "b"} + } + ] + } = stmt + end + + test "power has higher precedence than unary (special case)" do + # In Lua, -2^3 = -(2^3) = -8, not (-2)^3 + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return -2 ^ 3") + + # Should parse as: -(2 ^ 3) + assert %{ + values: [ + %Expr.UnOp{ + op: :neg, + operand: %Expr.BinOp{op: :pow} + } + ] + } = stmt + end + + test "not has higher precedence than and" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return not a and b") + + # Should parse as: (not a) and b + assert %{ + values: [ + %Expr.BinOp{ + op: :and, + left: %Expr.UnOp{op: :not}, + right: %Expr.Var{name: "b"} + } + ] + } = stmt + end + + test "length operator has same precedence as unary minus" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return #t + 1") + + # Should parse as: (#t) + 1 + assert %{ + values: [ + %Expr.BinOp{ + op: :add, + left: %Expr.UnOp{op: :len}, + right: %Expr.Number{value: 1} + } + ] + } = stmt + end + end + + describe "associativity" do + test "addition is left associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 1 + 2 + 3") + + # Should parse as: (1 + 2) + 3 + assert %{ + values: [ + %Expr.BinOp{ + op: :add, + left: %Expr.BinOp{ + op: :add, + left: %Expr.Number{value: 1}, + right: %Expr.Number{value: 2} + }, + right: %Expr.Number{value: 3} + } + ] + } = stmt + end + + test "subtraction is left associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 10 - 5 - 2") + + # Should parse as: (10 - 5) - 2 = 3 + assert %{ + values: [ + %Expr.BinOp{ + op: :sub, + left: %Expr.BinOp{op: :sub}, + right: %Expr.Number{value: 2} + } + ] + } = stmt + end + + test "multiplication is left associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 2 * 3 * 4") + + # Should parse as: (2 * 3) * 4 + assert %{ + values: [ + %Expr.BinOp{ + op: :mul, + left: %Expr.BinOp{op: :mul}, + right: %Expr.Number{value: 4} + } + ] + } = stmt + end + + test "division is left associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 24 / 4 / 2") + + # Should parse as: (24 / 4) / 2 = 3 + assert %{ + values: [ + %Expr.BinOp{ + op: :div, + left: %Expr.BinOp{op: :div}, + right: %Expr.Number{value: 2} + } + ] + } = stmt + end + + test "power is right associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 2 ^ 3 ^ 2") + + # Should parse as: 2 ^ (3 ^ 2) = 2 ^ 9 = 512 + assert %{ + values: [ + %Expr.BinOp{ + op: :pow, + left: %Expr.Number{value: 2}, + right: %Expr.BinOp{ + op: :pow, + left: %Expr.Number{value: 3}, + right: %Expr.Number{value: 2} + } + } + ] + } = stmt + end + + test "concatenation is right associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse(~s(return "a" .. "b" .. 
"c")) + + # Should parse as: "a" .. ("b" .. "c") + assert %{ + values: [ + %Expr.BinOp{ + op: :concat, + left: %Expr.String{value: "a"}, + right: %Expr.BinOp{ + op: :concat, + left: %Expr.String{value: "b"}, + right: %Expr.String{value: "c"} + } + } + ] + } = stmt + end + + test "and is left associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a and b and c") + + # Should parse as: (a and b) and c + assert %{ + values: [ + %Expr.BinOp{ + op: :and, + left: %Expr.BinOp{op: :and}, + right: %Expr.Var{name: "c"} + } + ] + } = stmt + end + + test "or is left associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a or b or c") + + # Should parse as: (a or b) or c + assert %{ + values: [ + %Expr.BinOp{ + op: :or, + left: %Expr.BinOp{op: :or}, + right: %Expr.Var{name: "c"} + } + ] + } = stmt + end + + test "comparison is left associative" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 1 < 2 < 3") + + # Should parse as: (1 < 2) < 3 + # Note: This is legal in Lua but semantically weird (compares boolean with number) + assert %{ + values: [ + %Expr.BinOp{ + op: :lt, + left: %Expr.BinOp{op: :lt}, + right: %Expr.Number{value: 3} + } + ] + } = stmt + end + end + + describe "complex precedence cases" do + test "all operators with correct precedence" do + # a or b and c < d .. e + f * g ^ h + # Should parse as: a or (b and (c < (d .. (e + (f * (g ^ h)))))) + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a or b and c < d .. e + f * g ^ h") + + assert %{ + values: [ + %Expr.BinOp{ + op: :or, + left: %Expr.Var{name: "a"}, + right: %Expr.BinOp{ + op: :and, + left: %Expr.Var{name: "b"}, + right: %Expr.BinOp{ + op: :lt, + left: %Expr.Var{name: "c"}, + right: %Expr.BinOp{ + op: :concat, + left: %Expr.Var{name: "d"}, + right: %Expr.BinOp{ + op: :add, + left: %Expr.Var{name: "e"}, + right: %Expr.BinOp{ + op: :mul, + left: %Expr.Var{name: "f"}, + right: %Expr.BinOp{ + op: :pow, + left: %Expr.Var{name: "g"}, + right: %Expr.Var{name: "h"} + } + } + } + } + } + } + } + ] + } = stmt + end + + test "unary operators with various binary operators" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return not a or -b and #c") + + # Should parse as: (not a) or ((-b) and (#c)) + assert %{ + values: [ + %Expr.BinOp{ + op: :or, + left: %Expr.UnOp{op: :not}, + right: %Expr.BinOp{ + op: :and, + left: %Expr.UnOp{op: :neg}, + right: %Expr.UnOp{op: :len} + } + } + ] + } = stmt + end + + test "mixed arithmetic with different precedences" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 1 + 2 * 3 - 4 / 5 % 6") + + # Should parse as: (1 + (2 * 3)) - ((4 / 5) % 6) + # With left associativity: ((1 + (2 * 3)) - ((4 / 5) % 6)) + # *, /, % have same precedence, so (4 / 5) % 6 parses left-to-right + # +, - have same precedence, so the top level is: (...) - (...) 
+ assert %{values: [%Expr.BinOp{op: :sub}]} = stmt + end + end + + describe "parentheses override precedence" do + test "parentheses around addition before multiplication" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return (1 + 2) * 3") + + # Should parse as: (1 + 2) * 3 + assert %{ + values: [ + %Expr.BinOp{ + op: :mul, + left: %Expr.BinOp{op: :add}, + right: %Expr.Number{value: 3} + } + ] + } = stmt + end + + test "parentheses around or before and" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return (a or b) and c") + + # Should parse as: (a or b) and c + assert %{ + values: [ + %Expr.BinOp{ + op: :and, + left: %Expr.BinOp{op: :or}, + right: %Expr.Var{name: "c"} + } + ] + } = stmt + end + + test "nested parentheses" do + assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return ((1 + 2) * 3) + 4") + + # Should parse as: ((1 + 2) * 3) + 4 + assert %{ + values: [ + %Expr.BinOp{ + op: :add, + left: %Expr.BinOp{ + op: :mul, + left: %Expr.BinOp{op: :add}, + right: %Expr.Number{value: 3} + }, + right: %Expr.Number{value: 4} + } + ] + } = stmt + end + end +end diff --git a/test/lua/parser/stmt_test.exs b/test/lua/parser/stmt_test.exs new file mode 100644 index 0000000..5697f88 --- /dev/null +++ b/test/lua/parser/stmt_test.exs @@ -0,0 +1,591 @@ +defmodule Lua.Parser.StmtTest do + use ExUnit.Case, async: true + alias Lua.Parser + alias Lua.AST.{Stmt, Expr} + + describe "local variable declarations" do + test "parses local without initialization" do + assert {:ok, chunk} = Parser.parse("local x") + assert %{block: %{stmts: [%Stmt.Local{names: ["x"], values: []}]}} = chunk + end + + test "parses local with single initialization" do + assert {:ok, chunk} = Parser.parse("local x = 42") + + assert %{ + block: %{ + stmts: [%Stmt.Local{names: ["x"], values: [%Expr.Number{value: 42}]}] + } + } = chunk + end + + test "parses local with multiple variables" do + assert {:ok, chunk} = Parser.parse("local x, y, z = 1, 2, 3") + + assert %{ + block: %{ + stmts: [ + %Stmt.Local{ + names: ["x", "y", "z"], + values: [ + %Expr.Number{value: 1}, + %Expr.Number{value: 2}, + %Expr.Number{value: 3} + ] + } + ] + } + } = chunk + end + + test "parses local function" do + assert {:ok, chunk} = + Parser.parse(""" + local function add(a, b) + return a + b + end + """) + + assert %{block: %{stmts: [%Stmt.LocalFunc{name: "add", params: ["a", "b"]}]}} = chunk + end + end + + describe "assignments" do + test "parses simple assignment" do + assert {:ok, chunk} = Parser.parse("x = 42") + + assert %{ + block: %{ + stmts: [ + %Stmt.Assign{ + targets: [%Expr.Var{name: "x"}], + values: [%Expr.Number{value: 42}] + } + ] + } + } = chunk + end + + test "parses multiple assignment" do + assert {:ok, chunk} = Parser.parse("x, y = 1, 2") + + assert %{ + block: %{ + stmts: [ + %Stmt.Assign{ + targets: [%Expr.Var{name: "x"}, %Expr.Var{name: "y"}], + values: [%Expr.Number{value: 1}, %Expr.Number{value: 2}] + } + ] + } + } = chunk + end + + test "parses table field assignment" do + assert {:ok, chunk} = Parser.parse("t.field = 42") + + assert %{ + block: %{ + stmts: [ + %Stmt.Assign{ + targets: [%Expr.Property{}], + values: [%Expr.Number{value: 42}] + } + ] + } + } = chunk + end + + test "parses indexed assignment" do + assert {:ok, chunk} = Parser.parse("t[1] = 42") + + assert %{ + block: %{ + stmts: [ + %Stmt.Assign{ + targets: [%Expr.Index{}], + values: [%Expr.Number{value: 42}] + } + ] + } + } = chunk + end + end + + describe "function calls as statements" do + test "parses function call statement" 
do + assert {:ok, chunk} = Parser.parse("print(42)") + + assert %{ + block: %{ + stmts: [%Stmt.CallStmt{call: %Expr.Call{func: %Expr.Var{name: "print"}}}] + } + } = chunk + end + + test "parses method call statement" do + assert {:ok, chunk} = Parser.parse("obj:method()") + + assert %{ + block: %{stmts: [%Stmt.CallStmt{call: %Expr.MethodCall{method: "method"}}]} + } = chunk + end + end + + describe "if statements" do + test "parses simple if statement" do + assert {:ok, chunk} = + Parser.parse(""" + if x > 0 then + return x + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.If{ + condition: %Expr.BinOp{op: :gt}, + then_block: %{stmts: [%Stmt.Return{}]}, + elseifs: [], + else_block: nil + } + ] + } + } = chunk + end + + test "parses if with else" do + assert {:ok, chunk} = + Parser.parse(""" + if x > 0 then + return x + else + return 0 + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.If{ + condition: %Expr.BinOp{op: :gt}, + then_block: %{stmts: [%Stmt.Return{}]}, + elseifs: [], + else_block: %{stmts: [%Stmt.Return{}]} + } + ] + } + } = chunk + end + + test "parses if with elseif" do + assert {:ok, chunk} = + Parser.parse(""" + if x > 0 then + return 1 + elseif x < 0 then + return -1 + else + return 0 + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.If{ + condition: %Expr.BinOp{op: :gt}, + then_block: %{stmts: [%Stmt.Return{}]}, + elseifs: [{%Expr.BinOp{op: :lt}, %{stmts: [%Stmt.Return{}]}}], + else_block: %{stmts: [%Stmt.Return{}]} + } + ] + } + } = chunk + end + + test "parses if with multiple elseifs" do + assert {:ok, chunk} = + Parser.parse(""" + if x == 1 then + return "one" + elseif x == 2 then + return "two" + elseif x == 3 then + return "three" + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.If{ + elseifs: [_, _] + } + ] + } + } = chunk + end + end + + describe "while loops" do + test "parses while loop" do + assert {:ok, chunk} = + Parser.parse(""" + while x > 0 do + x = x - 1 + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.While{ + condition: %Expr.BinOp{op: :gt}, + body: %{stmts: [%Stmt.Assign{}]} + } + ] + } + } = chunk + end + end + + describe "repeat-until loops" do + test "parses repeat-until loop" do + assert {:ok, chunk} = + Parser.parse(""" + repeat + x = x - 1 + until x == 0 + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.Repeat{ + body: %{stmts: [%Stmt.Assign{}]}, + condition: %Expr.BinOp{op: :eq} + } + ] + } + } = chunk + end + end + + describe "for loops" do + test "parses numeric for loop" do + assert {:ok, chunk} = + Parser.parse(""" + for i = 1, 10 do + print(i) + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.ForNum{ + var: "i", + start: %Expr.Number{value: 1}, + limit: %Expr.Number{value: 10}, + step: nil, + body: %{stmts: [%Stmt.CallStmt{}]} + } + ] + } + } = chunk + end + + test "parses numeric for loop with step" do + assert {:ok, chunk} = + Parser.parse(""" + for i = 1, 10, 2 do + print(i) + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.ForNum{ + var: "i", + start: %Expr.Number{value: 1}, + limit: %Expr.Number{value: 10}, + step: %Expr.Number{value: 2} + } + ] + } + } = chunk + end + + test "parses generic for loop" do + assert {:ok, chunk} = + Parser.parse(""" + for k, v in pairs(t) do + print(k, v) + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.ForIn{ + vars: ["k", "v"], + iterators: [%Expr.Call{}], + body: %{stmts: [%Stmt.CallStmt{}]} + } + ] + } + } = chunk + end + + test "parses generic for loop with single variable" do + assert {:ok, chunk} = + Parser.parse(""" + for line in 
io.lines() do + print(line) + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.ForIn{ + vars: ["line"], + iterators: [%Expr.Call{func: %Expr.Property{}}] + } + ] + } + } = chunk + end + end + + describe "function declarations" do + test "parses simple function declaration" do + assert {:ok, chunk} = + Parser.parse(""" + function add(a, b) + return a + b + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.FuncDecl{ + name: ["add"], + params: ["a", "b"], + is_method: false, + body: %{stmts: [%Stmt.Return{}]} + } + ] + } + } = chunk + end + + test "parses function declaration with dot notation" do + assert {:ok, chunk} = + Parser.parse(""" + function math.abs(x) + return x + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.FuncDecl{ + name: ["math", "abs"], + is_method: false + } + ] + } + } = chunk + end + + test "parses method declaration" do + assert {:ok, chunk} = + Parser.parse(""" + function obj:method(x) + return self.field + x + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.FuncDecl{ + name: ["obj", "method"], + is_method: true, + params: ["x"] + } + ] + } + } = chunk + end + + test "parses nested function name" do + assert {:ok, chunk} = + Parser.parse(""" + function a.b.c.d() + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.FuncDecl{ + name: ["a", "b", "c", "d"], + is_method: false + } + ] + } + } = chunk + end + end + + describe "do blocks" do + test "parses do block" do + assert {:ok, chunk} = + Parser.parse(""" + do + local x = 42 + print(x) + end + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.Do{ + body: %{stmts: [%Stmt.Local{}, %Stmt.CallStmt{}]} + } + ] + } + } = chunk + end + end + + describe "break and goto" do + test "parses break" do + assert {:ok, chunk} = Parser.parse("break") + assert %{block: %{stmts: [%Stmt.Break{}]}} = chunk + end + + test "parses goto" do + assert {:ok, chunk} = Parser.parse("goto finish") + assert %{block: %{stmts: [%Stmt.Goto{label: "finish"}]}} = chunk + end + + test "parses label" do + assert {:ok, chunk} = Parser.parse("::finish::") + assert %{block: %{stmts: [%Stmt.Label{name: "finish"}]}} = chunk + end + + test "parses goto and label together" do + assert {:ok, chunk} = + Parser.parse(""" + goto skip + print("skipped") + ::skip:: + print("not skipped") + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.Goto{label: "skip"}, + %Stmt.CallStmt{}, + %Stmt.Label{name: "skip"}, + %Stmt.CallStmt{} + ] + } + } = chunk + end + end + + describe "return statements" do + test "parses return with no values" do + assert {:ok, chunk} = Parser.parse("return") + assert %{block: %{stmts: [%Stmt.Return{values: []}]}} = chunk + end + + test "parses return with single value" do + assert {:ok, chunk} = Parser.parse("return 42") + + assert %{ + block: %{stmts: [%Stmt.Return{values: [%Expr.Number{value: 42}]}]} + } = chunk + end + + test "parses return with multiple values" do + assert {:ok, chunk} = Parser.parse("return 1, 2, 3") + + assert %{ + block: %{ + stmts: [ + %Stmt.Return{ + values: [ + %Expr.Number{value: 1}, + %Expr.Number{value: 2}, + %Expr.Number{value: 3} + ] + } + ] + } + } = chunk + end + end + + describe "complex programs" do + test "parses factorial function" do + assert {:ok, chunk} = + Parser.parse(""" + function factorial(n) + if n <= 1 then + return 1 + else + return n * factorial(n - 1) + end + end + """) + + assert %{block: %{stmts: [%Stmt.FuncDecl{name: ["factorial"]}]}} = chunk + end + + test "parses multiple statements" do + assert {:ok, chunk} = + Parser.parse(""" + local x = 10 + local y = 20 
+ local sum = x + y + print(sum) + """) + + assert %{ + block: %{ + stmts: [ + %Stmt.Local{names: ["x"]}, + %Stmt.Local{names: ["y"]}, + %Stmt.Local{names: ["sum"]}, + %Stmt.CallStmt{} + ] + } + } = chunk + end + + test "parses nested control structures" do + assert {:ok, _chunk} = + Parser.parse(""" + for i = 1, 10 do + if i % 2 == 0 then + print("even") + else + print("odd") + end + end + """) + end + end +end From 583310f88e204e280c05707a5ff85d55689b106d Mon Sep 17 00:00:00 2001 From: Dave Lucia Date: Thu, 5 Feb 2026 15:14:22 -0500 Subject: [PATCH 2/4] more tests --- test/lua/ast/meta_test.exs | 179 +++++++++++++++++ test/lua/parser/recovery_test.exs | 310 ++++++++++++++++++++++++++++++ 2 files changed, 489 insertions(+) create mode 100644 test/lua/ast/meta_test.exs create mode 100644 test/lua/parser/recovery_test.exs diff --git a/test/lua/ast/meta_test.exs b/test/lua/ast/meta_test.exs new file mode 100644 index 0000000..2ce9dd5 --- /dev/null +++ b/test/lua/ast/meta_test.exs @@ -0,0 +1,179 @@ +defmodule Lua.AST.MetaTest do + use ExUnit.Case, async: true + + alias Lua.AST.Meta + + describe "new/0" do + test "creates empty meta" do + meta = Meta.new() + assert %Meta{start: nil, end: nil, metadata: %{}} = meta + end + end + + describe "new/2" do + test "creates meta with start and end positions" do + start = %{line: 1, column: 1, byte_offset: 0} + finish = %{line: 1, column: 10, byte_offset: 9} + + meta = Meta.new(start, finish) + + assert meta.start == start + assert meta.end == finish + assert meta.metadata == %{} + end + + test "accepts nil positions" do + meta = Meta.new(nil, nil) + assert meta.start == nil + assert meta.end == nil + end + end + + describe "new/3" do + test "creates meta with metadata" do + start = %{line: 1, column: 1, byte_offset: 0} + finish = %{line: 1, column: 10, byte_offset: 9} + metadata = %{type: :test, custom: "value"} + + meta = Meta.new(start, finish, metadata) + + assert meta.start == start + assert meta.end == finish + assert meta.metadata == metadata + end + + test "accepts empty metadata map" do + start = %{line: 1, column: 1, byte_offset: 0} + finish = %{line: 1, column: 10, byte_offset: 9} + + meta = Meta.new(start, finish, %{}) + + assert meta.metadata == %{} + end + end + + describe "add_metadata/3" do + test "adds metadata to existing meta" do + meta = Meta.new() + meta = Meta.add_metadata(meta, :test, "value") + + assert meta.metadata == %{test: "value"} + end + + test "adds multiple metadata fields" do + meta = Meta.new(nil, nil, %{a: 1}) + meta = Meta.add_metadata(meta, :b, 2) + + assert meta.metadata == %{a: 1, b: 2} + end + + test "overwrites existing keys" do + meta = Meta.new(nil, nil, %{a: 1}) + meta = Meta.add_metadata(meta, :a, 2) + + assert meta.metadata == %{a: 2} + end + end + + describe "merge/2" do + test "merges two metas taking earliest start" do + meta1 = Meta.new(%{line: 1, column: 5, byte_offset: 10}, %{line: 1, column: 10, byte_offset: 20}) + meta2 = Meta.new(%{line: 1, column: 1, byte_offset: 5}, %{line: 1, column: 8, byte_offset: 15}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 5} + end + + test "merges two metas taking latest end" do + meta1 = Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 10, byte_offset: 9}) + meta2 = Meta.new(%{line: 1, column: 5, byte_offset: 4}, %{line: 1, column: 20, byte_offset: 19}) + + merged = Meta.merge(meta1, meta2) + + assert merged.end == %{line: 1, column: 20, byte_offset: 19} + end + + test "handles nil positions" do 
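+      # when one span is missing an endpoint, merge/2 falls back to the other:
+      # the non-nil start and the non-nil end should win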
+ meta1 = Meta.new(nil, %{line: 1, column: 10, byte_offset: 9}) + meta2 = Meta.new(%{line: 1, column: 1, byte_offset: 0}, nil) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + end + + describe "position tracking" do + test "stores line numbers" do + meta = Meta.new(%{line: 5, column: 10, byte_offset: 50}, %{line: 5, column: 20, byte_offset: 60}) + + assert meta.start.line == 5 + assert meta.end.line == 5 + end + + test "stores column numbers" do + meta = Meta.new(%{line: 1, column: 5, byte_offset: 4}, %{line: 1, column: 15, byte_offset: 14}) + + assert meta.start.column == 5 + assert meta.end.column == 15 + end + + test "stores byte offsets" do + meta = Meta.new(%{line: 1, column: 1, byte_offset: 100}, %{line: 1, column: 10, byte_offset: 200}) + + assert meta.start.byte_offset == 100 + assert meta.end.byte_offset == 200 + end + + test "handles multiline spans" do + meta = Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 10, column: 5, byte_offset: 150}) + + assert meta.start.line == 1 + assert meta.end.line == 10 + end + end + + describe "metadata storage" do + test "stores arbitrary data" do + meta = Meta.new(nil, nil, %{ + node_type: :function, + name: "test", + params: ["a", "b"], + is_async: false + }) + + assert meta.metadata.node_type == :function + assert meta.metadata.name == "test" + assert meta.metadata.params == ["a", "b"] + assert meta.metadata.is_async == false + end + + test "stores nested data structures" do + meta = Meta.new(nil, nil, %{ + scope: %{ + variables: ["x", "y"], + functions: ["f", "g"] + } + }) + + assert meta.metadata.scope.variables == ["x", "y"] + assert meta.metadata.scope.functions == ["f", "g"] + end + end + + describe "struct validation" do + test "has correct struct fields" do + meta = %Meta{} + assert Map.has_key?(meta, :start) + assert Map.has_key?(meta, :end) + assert Map.has_key?(meta, :metadata) + end + + test "is a proper struct" do + meta = %Meta{} + assert meta.__struct__ == Lua.AST.Meta + end + end +end diff --git a/test/lua/parser/recovery_test.exs b/test/lua/parser/recovery_test.exs new file mode 100644 index 0000000..7f44314 --- /dev/null +++ b/test/lua/parser/recovery_test.exs @@ -0,0 +1,310 @@ +defmodule Lua.Parser.RecoveryTest do + use ExUnit.Case, async: true + + alias Lua.Parser.Recovery + alias Lua.Parser.Error + + describe "recover_at_statement/2" do + test "recovers at semicolon" do + tokens = [ + {:delimiter, :semicolon, %{line: 1, column: 5}}, + {:keyword, :end, %{line: 1, column: 7}}, + {:eof, %{line: 1, column: 10}} + ] + + error = Error.new(:unexpected_token, "Unexpected token", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error) + assert [{:delimiter, :semicolon, _} | _] = rest + end + + test "recovers at end keyword" do + tokens = [ + {:keyword, :end, %{line: 1, column: 1}}, + {:eof, %{line: 1, column: 4}} + ] + + error = Error.new(:unexpected_token, "Test error", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error) + assert [{:keyword, :end, _} | _] = rest + end + + test "recovers at statement keywords" do + keywords = [:if, :while, :for, :function, :local, :do, :repeat] + + for kw <- keywords do + tokens = [ + {:keyword, kw, %{line: 1, column: 1}}, + {:eof, %{line: 1, column: 10}} + ] + + error = Error.new(:unexpected_token, "Test", %{line: 1, column: 1}) + assert {:recovered, _, 
[^error]} = Recovery.recover_at_statement(tokens, error) + end + end + + test "recovers when only EOF remains (EOF is a boundary)" do + tokens = [ + {:identifier, "x", %{line: 1, column: 1}}, + {:operator, :assign, %{line: 1, column: 3}}, + {:eof, %{line: 1, column: 5}} + ] + + error = Error.new(:unexpected_token, "Test", %{line: 1, column: 1}) + + # EOF is actually a statement boundary, so this recovers + assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error) + assert length(rest) >= 1 + end + + test "recovers at EOF" do + tokens = [{:eof, %{line: 1, column: 1}}] + error = Error.new(:unexpected_token, "Test", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error) + assert [{:eof, _}] = rest + end + end + + describe "recover_unclosed_delimiter/3" do + test "finds closing parenthesis" do + tokens = [ + {:delimiter, :rparen, %{line: 1, column: 10}}, + {:eof, %{line: 1, column: 11}} + ] + + error = Error.new(:unclosed_delimiter, "Unclosed (", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert [{:eof, _}] = rest + end + + test "finds closing bracket" do + tokens = [ + {:delimiter, :rbracket, %{line: 1, column: 10}}, + {:eof, %{line: 1, column: 11}} + ] + + error = Error.new(:unclosed_delimiter, "Unclosed [", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lbracket, error) + assert [{:eof, _}] = rest + end + + test "finds closing brace" do + tokens = [ + {:delimiter, :rbrace, %{line: 1, column: 10}}, + {:eof, %{line: 1, column: 11}} + ] + + error = Error.new(:unclosed_delimiter, "Unclosed {", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lbrace, error) + assert [{:eof, _}] = rest + end + + test "handles nested delimiters" do + tokens = [ + {:delimiter, :lparen, %{line: 1, column: 2}}, + {:delimiter, :rparen, %{line: 1, column: 3}}, + {:delimiter, :rparen, %{line: 1, column: 4}}, + {:eof, %{line: 1, column: 5}} + ] + + error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert [{:eof, _}] = rest + end + + test "falls back to statement boundary if delimiter not found" do + tokens = [ + {:keyword, :end, %{line: 1, column: 5}}, + {:eof, %{line: 1, column: 8}} + ] + + error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert [{:keyword, :end, _} | _] = rest + end + end + + describe "recover_missing_keyword/3" do + test "finds the missing keyword" do + tokens = [ + {:identifier, "x", %{line: 1, column: 1}}, + {:keyword, :then, %{line: 1, column: 3}}, + {:keyword, :end, %{line: 1, column: 8}} + ] + + error = Error.new(:expected_token, "Expected then", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error) + assert [{:keyword, :then, _} | _] = rest + end + + test "falls back to statement boundary if keyword not found" do + tokens = [ + {:keyword, :end, %{line: 1, column: 1}}, + {:eof, %{line: 1, column: 4}} + ] + + error = Error.new(:expected_token, "Expected then", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error) + assert [{:keyword, :end, 
_} | _] = rest + end + end + + describe "skip_to_statement/1" do + test "skips to statement boundary" do + tokens = [ + {:identifier, "x", %{line: 1, column: 1}}, + {:operator, :assign, %{line: 1, column: 3}}, + {:keyword, :end, %{line: 1, column: 5}}, + {:eof, %{line: 1, column: 8}} + ] + + rest = Recovery.skip_to_statement(tokens) + assert [{:keyword, :end, _} | _] = rest + end + + test "returns empty list if no boundary found" do + tokens = [ + {:identifier, "x", %{line: 1, column: 1}}, + {:operator, :assign, %{line: 1, column: 3}} + ] + + assert [] = Recovery.skip_to_statement(tokens) + end + + test "returns tokens if already at boundary" do + tokens = [ + {:keyword, :end, %{line: 1, column: 1}}, + {:eof, %{line: 1, column: 4}} + ] + + assert ^tokens = Recovery.skip_to_statement(tokens) + end + end + + describe "is_statement_boundary?/1" do + test "recognizes delimiters" do + assert Recovery.is_statement_boundary?({:delimiter, :semicolon, %{line: 1, column: 1}}) + end + + test "recognizes block terminators" do + terminators = [:end, :else, :elseif, :until] + + for term <- terminators do + assert Recovery.is_statement_boundary?({:keyword, term, %{line: 1, column: 1}}) + end + end + + test "recognizes statement starters" do + starters = [:if, :while, :for, :function, :local, :do, :repeat] + + for starter <- starters do + assert Recovery.is_statement_boundary?({:keyword, starter, %{line: 1, column: 1}}) + end + end + + test "recognizes EOF" do + assert Recovery.is_statement_boundary?({:eof, %{line: 1, column: 1}}) + end + + test "rejects non-boundaries" do + refute Recovery.is_statement_boundary?({:identifier, "x", %{line: 1, column: 1}}) + refute Recovery.is_statement_boundary?({:number, 42, %{line: 1, column: 1}}) + refute Recovery.is_statement_boundary?({:operator, :add, %{line: 1, column: 1}}) + end + end + + describe "DelimiterStack" do + alias Recovery.DelimiterStack + + test "creates empty stack" do + stack = DelimiterStack.new() + assert DelimiterStack.empty?(stack) + end + + test "pushes delimiter" do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1}) + refute DelimiterStack.empty?(stack) + end + + test "pops matching delimiter" do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1}) + + assert {:ok, stack} = DelimiterStack.pop(stack, :rparen) + assert DelimiterStack.empty?(stack) + end + + test "fails on mismatched delimiter" do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1}) + + assert {:error, :mismatched, :lparen} = DelimiterStack.pop(stack, :rbracket) + end + + test "fails on empty stack" do + stack = DelimiterStack.new() + assert {:error, :empty} = DelimiterStack.pop(stack, :rparen) + end + + test "peeks at top delimiter" do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1}) + + assert {:ok, :lparen, %{line: 1, column: 1}} = DelimiterStack.peek(stack) + end + + test "peek returns empty on empty stack" do + stack = DelimiterStack.new() + assert :empty = DelimiterStack.peek(stack) + end + + test "handles nested delimiters" do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1}) + stack = DelimiterStack.push(stack, :lbracket, %{line: 1, column: 5}) + stack = DelimiterStack.push(stack, :lbrace, %{line: 1, column: 10}) + + assert {:ok, stack} = DelimiterStack.pop(stack, :rbrace) + assert {:ok, stack} = DelimiterStack.pop(stack, :rbracket) 
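+      # the paren was pushed first, so it must be the last delimiter to close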
+ assert {:ok, stack} = DelimiterStack.pop(stack, :rparen) + assert DelimiterStack.empty?(stack) + end + + test "handles keyword delimiters" do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, :function, %{line: 1, column: 1}) + + assert {:ok, stack} = DelimiterStack.pop(stack, :end) + assert DelimiterStack.empty?(stack) + end + + test "matches all delimiter pairs" do + pairs = [ + {:lparen, :rparen}, + {:lbracket, :rbracket}, + {:lbrace, :rbrace}, + {:function, :end}, + {:if, :end}, + {:while, :end}, + {:for, :end}, + {:do, :end} + ] + + for {open, close} <- pairs do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, open, %{line: 1, column: 1}) + assert {:ok, _} = DelimiterStack.pop(stack, close) + end + end + end +end From 8c7e857b3a1af7532cd4d7728a4fa656820cc8a6 Mon Sep 17 00:00:00 2001 From: Dave Lucia Date: Thu, 5 Feb 2026 16:48:09 -0500 Subject: [PATCH 3/4] test cleanup --- lib/lua/ast/builder.ex | 6 +- lib/lua/ast/pretty_printer.ex | 20 +- lib/lua/ast/walker.ex | 130 ++- lib/lua/lexer.ex | 13 +- lib/lua/parser.ex | 28 +- lib/lua/parser/error.ex | 3 +- lib/lua/parser/pratt.ex | 3 +- test/lua/ast/builder_test.exs | 417 ++++--- test/lua/ast/meta_test.exs | 114 +- test/lua/ast/pretty_printer_test.exs | 919 ++++++++++++++-- test/lua/ast/walker_test.exs | 1221 ++++++++++++++++++--- test/lua/lexer_test.exs | 222 +++- test/lua/parser/beautiful_errors_test.exs | 269 ----- test/lua/parser/error_test.exs | 190 ++-- test/lua/parser/error_unit_test.exs | 627 +++++++++++ test/lua/parser/pratt_test.exs | 253 +++++ test/lua/parser/precedence_test.exs | 3 +- test/lua/parser/recovery_test.exs | 121 +- test/lua/runtime_exception_test.exs | 434 ++++++++ 19 files changed, 4131 insertions(+), 862 deletions(-) delete mode 100644 test/lua/parser/beautiful_errors_test.exs create mode 100644 test/lua/parser/error_unit_test.exs create mode 100644 test/lua/parser/pratt_test.exs create mode 100644 test/lua/runtime_exception_test.exs diff --git a/lib/lua/ast/builder.ex b/lib/lua/ast/builder.ex index c3b8efe..4d53ffc 100644 --- a/lib/lua/ast/builder.ex +++ b/lib/lua/ast/builder.ex @@ -201,7 +201,8 @@ defmodule Lua.AST.Builder do {:record, string("y"), number(20)} ]) """ - @spec table([{:list, Expr.t()} | {:record, Expr.t(), Expr.t()}], Meta.t() | nil) :: Expr.Table.t() + @spec table([{:list, Expr.t()} | {:record, Expr.t(), Expr.t()}], Meta.t() | nil) :: + Expr.Table.t() def table(fields, meta \\ nil) do %Expr.Table{ fields: fields, @@ -361,7 +362,8 @@ defmodule Lua.AST.Builder do # function math.add(a, b) return a + b end func_decl(["math", "add"], ["a", "b"], [...]) """ - @spec func_decl(String.t() | [String.t()], [String.t()], [Stmt.t()], keyword()) :: Stmt.FuncDecl.t() + @spec func_decl(String.t() | [String.t()], [String.t()], [Stmt.t()], keyword()) :: + Stmt.FuncDecl.t() def func_decl(name, params, body_stmts, opts \\ []) do name_parts = if is_binary(name), do: [name], else: name diff --git a/lib/lua/ast/pretty_printer.ex b/lib/lua/ast/pretty_printer.ex index c326632..a97219f 100644 --- a/lib/lua/ast/pretty_printer.ex +++ b/lib/lua/ast/pretty_printer.ex @@ -219,7 +219,16 @@ defmodule Lua.AST.PrettyPrinter do "#{indent(level, indent_size)}#{do_print(call, level, indent_size)}" end - defp do_print(%Stmt.If{condition: cond, then_block: then_block, elseifs: elseifs, else_block: else_block}, level, indent_size) do + defp do_print( + %Stmt.If{ + condition: cond, + then_block: then_block, + elseifs: elseifs, + else_block: else_block + }, + level, + indent_size + ) do cond_str = 
do_print(cond, level, indent_size) then_str = print_block_body(then_block, level + 1, indent_size) @@ -264,7 +273,11 @@ defmodule Lua.AST.PrettyPrinter do "#{indent(level, indent_size)}repeat\n#{body_str}#{indent(level, indent_size)}until #{cond_str}" end - defp do_print(%Stmt.ForNum{var: var, start: start, limit: limit, step: step, body: body}, level, indent_size) do + defp do_print( + %Stmt.ForNum{var: var, start: start, limit: limit, step: step, body: body}, + level, + indent_size + ) do start_str = do_print(start, level, indent_size) limit_str = do_print(limit, level, indent_size) body_str = print_block_body(body, level + 1, indent_size) @@ -357,7 +370,8 @@ defmodule Lua.AST.PrettyPrinter do %Expr.UnOp{} -> # Unary ops have high precedence, rarely need parens case parent_op do - :pow -> true # -2^3 should be -(2^3) + # -2^3 should be -(2^3) + :pow -> true _ -> false end diff --git a/lib/lua/ast/walker.ex b/lib/lua/ast/walker.ex index 5adc6cc..ddd0aed 100644 --- a/lib/lua/ast/walker.ex +++ b/lib/lua/ast/walker.ex @@ -168,7 +168,11 @@ defmodule Lua.AST.Walker do # Statements %Stmt.Assign{targets: targets, values: values} = stmt -> - %{stmt | targets: Enum.map(targets, &do_map(&1, mapper)), values: Enum.map(values, &do_map(&1, mapper))} + %{ + stmt + | targets: Enum.map(targets, &do_map(&1, mapper)), + values: Enum.map(values, &do_map(&1, mapper)) + } %Stmt.Local{values: values} = stmt when is_list(values) -> %{stmt | values: Enum.map(values, &do_map(&1, mapper))} @@ -185,15 +189,23 @@ defmodule Lua.AST.Walker do %Stmt.CallStmt{call: call} = stmt -> %{stmt | call: do_map(call, mapper)} - %Stmt.If{condition: cond, then_block: then_block, elseifs: elseifs, else_block: else_block} = stmt -> - mapped_elseifs = Enum.map(elseifs, fn {c, b} -> {do_map(c, mapper), do_map(b, mapper)} end) + %Stmt.If{ + condition: cond, + then_block: then_block, + elseifs: elseifs, + else_block: else_block + } = stmt -> + mapped_elseifs = + Enum.map(elseifs, fn {c, b} -> {do_map(c, mapper), do_map(b, mapper)} end) + mapped_else = if else_block, do: do_map(else_block, mapper), else: nil - %{stmt | - condition: do_map(cond, mapper), - then_block: do_map(then_block, mapper), - elseifs: mapped_elseifs, - else_block: mapped_else + %{ + stmt + | condition: do_map(cond, mapper), + then_block: do_map(then_block, mapper), + elseifs: mapped_elseifs, + else_block: mapped_else } %Stmt.While{condition: cond, body: body} = stmt -> @@ -202,18 +214,23 @@ defmodule Lua.AST.Walker do %Stmt.Repeat{body: body, condition: cond} = stmt -> %{stmt | body: do_map(body, mapper), condition: do_map(cond, mapper)} - %Stmt.ForNum{var: var, start: start, limit: limit, step: step, body: body} = stmt -> + %Stmt.ForNum{var: _var, start: start, limit: limit, step: step, body: body} = stmt -> mapped_step = if step, do: do_map(step, mapper), else: nil - %{stmt | - start: do_map(start, mapper), - limit: do_map(limit, mapper), - step: mapped_step, - body: do_map(body, mapper) + %{ + stmt + | start: do_map(start, mapper), + limit: do_map(limit, mapper), + step: mapped_step, + body: do_map(body, mapper) } - %Stmt.ForIn{vars: vars, iterators: iterators, body: body} = stmt -> - %{stmt | iterators: Enum.map(iterators, &do_map(&1, mapper)), body: do_map(body, mapper)} + %Stmt.ForIn{vars: _vars, iterators: iterators, body: body} = stmt -> + %{ + stmt + | iterators: Enum.map(iterators, &do_map(&1, mapper)), + body: do_map(body, mapper) + } %Stmt.Do{body: body} = stmt -> %{stmt | body: do_map(body, mapper)} @@ -241,40 +258,79 @@ defmodule Lua.AST.Walker do defp 
children(node) do case node do # Chunk - %Chunk{block: block} -> [block] + %Chunk{block: block} -> + [block] # Block - %Block{stmts: stmts} -> stmts + %Block{stmts: stmts} -> + stmts # Expressions with children - %Expr.BinOp{left: left, right: right} -> [left, right] - %Expr.UnOp{operand: operand} -> [operand] - %Expr.Table{fields: fields} -> extract_table_fields(fields) - %Expr.Call{func: func, args: args} -> [func | args] - %Expr.MethodCall{object: obj, args: args} -> [obj | args] - %Expr.Index{table: table, key: key} -> [table, key] - %Expr.Property{table: table} -> [table] - %Expr.Function{body: body} -> [body] + %Expr.BinOp{left: left, right: right} -> + [left, right] + + %Expr.UnOp{operand: operand} -> + [operand] + + %Expr.Table{fields: fields} -> + extract_table_fields(fields) + + %Expr.Call{func: func, args: args} -> + [func | args] + + %Expr.MethodCall{object: obj, args: args} -> + [obj | args] + + %Expr.Index{table: table, key: key} -> + [table, key] + + %Expr.Property{table: table} -> + [table] + + %Expr.Function{body: body} -> + [body] # Statements with children - %Stmt.Assign{targets: targets, values: values} -> targets ++ values - %Stmt.Local{values: values} when is_list(values) -> values - %Stmt.LocalFunc{body: body} -> [body] - %Stmt.FuncDecl{body: body} -> [body] - %Stmt.CallStmt{call: call} -> [call] + %Stmt.Assign{targets: targets, values: values} -> + targets ++ values + + %Stmt.Local{values: values} when is_list(values) -> + values + + %Stmt.LocalFunc{body: body} -> + [body] + + %Stmt.FuncDecl{body: body} -> + [body] + + %Stmt.CallStmt{call: call} -> + [call] + %Stmt.If{condition: cond, then_block: then_block, elseifs: elseifs, else_block: else_block} -> elseif_nodes = Enum.flat_map(elseifs, fn {c, b} -> [c, b] end) [cond, then_block | elseif_nodes] ++ if(else_block, do: [else_block], else: []) - %Stmt.While{condition: cond, body: body} -> [cond, body] - %Stmt.Repeat{body: body, condition: cond} -> [body, cond] + + %Stmt.While{condition: cond, body: body} -> + [cond, body] + + %Stmt.Repeat{body: body, condition: cond} -> + [body, cond] + %Stmt.ForNum{start: start, limit: limit, step: step, body: body} -> [start, limit] ++ if(step, do: [step], else: []) ++ [body] - %Stmt.ForIn{iterators: iterators, body: body} -> iterators ++ [body] - %Stmt.Do{body: body} -> [body] - %Stmt.Return{values: values} -> values + + %Stmt.ForIn{iterators: iterators, body: body} -> + iterators ++ [body] + + %Stmt.Do{body: body} -> + [body] + + %Stmt.Return{values: values} -> + values # Leaf nodes (no children) - _ -> [] + _ -> + [] end end diff --git a/lib/lua/lexer.ex b/lib/lua/lexer.ex index 6774b77..03400de 100644 --- a/lib/lua/lexer.ex +++ b/lib/lua/lexer.ex @@ -249,11 +249,6 @@ defmodule Lua.Lexer do scan_multiline_comment_content(rest, acc, advance_column(pos, 1), level) end - defp scan_multiline_comment(rest, acc, pos, _level) do - # Not a multi-line comment after all, treat as single-line - scan_single_line_comment(rest, acc, pos) - end - defp scan_multiline_comment_content(<<"]", rest::binary>>, acc, pos, level) do case try_close_long_bracket(rest, level, 0) do {:ok, after_bracket} -> @@ -409,7 +404,7 @@ defmodule Lua.Lexer do defp scan_number(<>, num_acc, acc, pos, start_pos) do # Trailing dot is not part of the number - finalize_number(num_acc, <<".">> , acc, pos, start_pos) + finalize_number(num_acc, <<".">>, acc, pos, start_pos) end defp scan_number(<<".", c, rest::binary>>, num_acc, acc, pos, start_pos) @@ -508,10 +503,8 @@ defmodule Lua.Lexer do end true -> - case 
Integer.parse(num_str) do - {num, ""} -> {:ok, num} - _ -> {:error, :invalid_number} - end + {num, ""} = Integer.parse(num_str) + {:ok, num} end end diff --git a/lib/lua/parser.ex b/lib/lua/parser.ex index cdc684c..bc22f79 100644 --- a/lib/lua/parser.ex +++ b/lib/lua/parser.ex @@ -231,7 +231,8 @@ defmodule Lua.Parser do end _ -> - {:error, {:unexpected_token, peek(rest), "Expected identifier or 'function' after 'local'"}} + {:error, + {:unexpected_token, peek(rest), "Expected identifier or 'function' after 'local'"}} end end @@ -509,8 +510,7 @@ defmodule Lua.Parser do {:ok, %Stmt.CallStmt{call: call, meta: nil}, rest} _ -> - {:error, - {:unexpected_expression, "Expression statement must be a function call"}} + {:error, {:unexpected_expression, "Expression statement must be a function call"}} end end @@ -597,15 +597,15 @@ defmodule Lua.Parser do defp parse_prefix(tokens) do case peek(tokens) do # Literals - {:keyword, :nil, pos} -> + {:keyword, nil, pos} -> {_, rest} = consume(tokens) {:ok, %Expr.Nil{meta: Meta.new(pos)}, rest} - {:keyword, :true, pos} -> + {:keyword, true, pos} -> {_, rest} = consume(tokens) {:ok, %Expr.Bool{value: true, meta: Meta.new(pos)}, rest} - {:keyword, :false, pos} -> + {:keyword, false, pos} -> {_, rest} = consume(tokens) {:ok, %Expr.Bool{value: false, meta: Meta.new(pos)}, rest} @@ -652,6 +652,9 @@ defmodule Lua.Parser do {_, rest} = consume(tokens) parse_unary(:len, pos, rest) + {:eof, pos} -> + {:error, {:unexpected_token, :eof, pos, "Expected expression"}} + {type, _, pos} -> {:error, {:unexpected_token, type, pos, "Expected expression"}} @@ -1040,12 +1043,14 @@ defmodule Lua.Parser do {type, _, pos} when type != nil -> {:error, - {:unexpected_token, type, pos, "Expected #{inspect(expected_type)}, got #{inspect(type)}"}} + {:unexpected_token, type, pos, + "Expected #{inspect(expected_type)}, got #{inspect(type)}"}} {type, pos} when is_map(pos) -> # Token without value (like :eof) {:error, - {:unexpected_token, type, pos, "Expected #{inspect(expected_type)}, got #{inspect(type)}"}} + {:unexpected_token, type, pos, + "Expected #{inspect(expected_type)}, got #{inspect(type)}"}} nil -> {:error, {:unexpected_end, "Expected #{inspect(expected_type)}"}} @@ -1071,16 +1076,15 @@ defmodule Lua.Parser do "Expected #{inspect(expected_type)}:#{inspect(expected_value)}, got #{inspect(type)}"}} nil -> - {:error, {:unexpected_end, "Expected #{inspect(expected_type)}:#{inspect(expected_value)}"}} + {:error, + {:unexpected_end, "Expected #{inspect(expected_type)}:#{inspect(expected_value)}"}} end end # Error conversion helpers defp convert_error({:unexpected_token, type, pos, message}, _code) do - Error.new(:unexpected_token, message, pos, - suggestion: suggest_for_token_error(type, message) - ) + Error.new(:unexpected_token, message, pos, suggestion: suggest_for_token_error(type, message)) end defp convert_error({:unexpected_end, message}, _code) do diff --git a/lib/lua/parser/error.ex b/lib/lua/parser/error.ex index ec9345c..73430b3 100644 --- a/lib/lua/parser/error.ex +++ b/lib/lua/parser/error.ex @@ -184,7 +184,8 @@ defmodule Lua.Parser.Error do @spec format_multiple([t()], String.t()) :: String.t() def format_multiple(errors, source_code) do header = [ - IO.ANSI.red() <> IO.ANSI.bright() <> + IO.ANSI.red() <> + IO.ANSI.bright() <> "Found #{length(errors)} parse error#{if length(errors) == 1, do: "", else: "s"}" <> IO.ANSI.reset(), "" diff --git a/lib/lua/parser/pratt.ex b/lib/lua/parser/pratt.ex index 9ce6d4c..c005ef9 100644 --- a/lib/lua/parser/pratt.ex +++ 
b/lib/lua/parser/pratt.ex @@ -77,7 +77,8 @@ defmodule Lua.Parser.Pratt do """ @spec prefix_binding_power(atom()) :: non_neg_integer() | nil def prefix_binding_power(:not), do: 14 - def prefix_binding_power(:sub), do: 13 # Between mult (11) and power (16) + # Between mult (11) and power (16) + def prefix_binding_power(:sub), do: 13 def prefix_binding_power(:len), do: 14 def prefix_binding_power(_), do: nil diff --git a/test/lua/ast/builder_test.exs b/test/lua/ast/builder_test.exs index c33476f..ceaf392 100644 --- a/test/lua/ast/builder_test.exs +++ b/test/lua/ast/builder_test.exs @@ -57,24 +57,44 @@ defmodule Lua.AST.BuilderTest do test "creates chained property access" do prop = property(property(var("a"), "b"), "c") + assert %Expr.Property{ - table: %Expr.Property{ - table: %Expr.Var{name: "a"}, - field: "b" - }, - field: "c" - } = prop + table: %Expr.Property{ + table: %Expr.Var{name: "a"}, + field: "b" + }, + field: "c" + } = prop end end describe "operators" do test "creates binary operation" do op = binop(:add, number(2), number(3)) - assert %Expr.BinOp{op: :add, left: %Expr.Number{value: 2}, right: %Expr.Number{value: 3}} = op + + assert %Expr.BinOp{op: :add, left: %Expr.Number{value: 2}, right: %Expr.Number{value: 3}} = + op end test "creates all binary operators" do - ops = [:add, :sub, :mul, :div, :floor_div, :mod, :pow, :concat, :eq, :ne, :lt, :gt, :le, :ge, :and, :or] + ops = [ + :add, + :sub, + :mul, + :div, + :floor_div, + :mod, + :pow, + :concat, + :eq, + :ne, + :lt, + :gt, + :le, + :ge, + :and, + :or + ] for op <- ops do assert %Expr.BinOp{op: ^op} = binop(op, number(1), number(2)) @@ -95,11 +115,12 @@ defmodule Lua.AST.BuilderTest do test "creates nested operations" do # (2 + 3) * 4 op = binop(:mul, binop(:add, number(2), number(3)), number(4)) + assert %Expr.BinOp{ - op: :mul, - left: %Expr.BinOp{op: :add}, - right: %Expr.Number{value: 4} - } = op + op: :mul, + left: %Expr.BinOp{op: :add}, + right: %Expr.Number{value: 4} + } = op end end @@ -110,27 +131,33 @@ defmodule Lua.AST.BuilderTest do end test "creates array-style table" do - tbl = table([ - {:list, number(1)}, - {:list, number(2)}, - {:list, number(3)} - ]) + tbl = + table([ + {:list, number(1)}, + {:list, number(2)}, + {:list, number(3)} + ]) + assert %Expr.Table{fields: [{:list, _}, {:list, _}, {:list, _}]} = tbl end test "creates record-style table" do - tbl = table([ - {:record, string("x"), number(10)}, - {:record, string("y"), number(20)} - ]) + tbl = + table([ + {:record, string("x"), number(10)}, + {:record, string("y"), number(20)} + ]) + assert %Expr.Table{fields: [{:record, _, _}, {:record, _, _}]} = tbl end test "creates mixed table" do - tbl = table([ - {:list, number(1)}, - {:record, string("x"), number(10)} - ]) + tbl = + table([ + {:list, number(1)}, + {:record, string("x"), number(10)} + ]) + assert %Expr.Table{fields: [{:list, _}, {:record, _, _}]} = tbl end end @@ -138,10 +165,11 @@ defmodule Lua.AST.BuilderTest do describe "function calls" do test "creates function call" do c = call(var("print"), [string("hello")]) + assert %Expr.Call{ - func: %Expr.Var{name: "print"}, - args: [%Expr.String{value: "hello"}] - } = c + func: %Expr.Var{name: "print"}, + args: [%Expr.String{value: "hello"}] + } = c end test "creates function call with multiple arguments" do @@ -151,21 +179,23 @@ defmodule Lua.AST.BuilderTest do test "creates method call" do mc = method_call(var("file"), "read", [string("*a")]) + assert %Expr.MethodCall{ - object: %Expr.Var{name: "file"}, - method: "read", - args: 
[%Expr.String{value: "*a"}] - } = mc + object: %Expr.Var{name: "file"}, + method: "read", + args: [%Expr.String{value: "*a"}] + } = mc end end describe "function expressions" do test "creates simple function" do fn_expr = function_expr(["x"], [return_stmt([var("x")])]) + assert %Expr.Function{ - params: ["x"], - body: %Block{stmts: [%Stmt.Return{}]} - } = fn_expr + params: ["x"], + body: %Block{stmts: [%Stmt.Return{}]} + } = fn_expr end test "creates function with multiple parameters" do @@ -182,10 +212,11 @@ defmodule Lua.AST.BuilderTest do describe "statements" do test "creates assignment" do stmt = assign([var("x")], [number(42)]) + assert %Stmt.Assign{ - targets: [%Expr.Var{name: "x"}], - values: [%Expr.Number{value: 42}] - } = stmt + targets: [%Expr.Var{name: "x"}], + values: [%Expr.Number{value: 42}] + } = stmt end test "creates multiple assignment" do @@ -205,11 +236,12 @@ defmodule Lua.AST.BuilderTest do test "creates local function" do stmt = local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + assert %Stmt.LocalFunc{ - name: "add", - params: ["a", "b"], - body: %Block{} - } = stmt + name: "add", + params: ["a", "b"], + body: %Block{} + } = stmt end test "creates function declaration with string name" do @@ -218,7 +250,9 @@ defmodule Lua.AST.BuilderTest do end test "creates function declaration with path name" do - stmt = func_decl(["math", "add"], ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + stmt = + func_decl(["math", "add"], ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + assert %Stmt.FuncDecl{name: ["math", "add"]} = stmt end @@ -254,95 +288,116 @@ defmodule Lua.AST.BuilderTest do describe "control flow" do test "creates if statement" do stmt = if_stmt(var("x"), [return_stmt([number(1)])]) + assert %Stmt.If{ - condition: %Expr.Var{name: "x"}, - then_block: %Block{stmts: [%Stmt.Return{}]}, - elseifs: [], - else_block: nil - } = stmt + condition: %Expr.Var{name: "x"}, + then_block: %Block{stmts: [%Stmt.Return{}]}, + elseifs: [], + else_block: nil + } = stmt end test "creates if-else statement" do - stmt = if_stmt( - var("x"), - [return_stmt([number(1)])], - else: [return_stmt([number(0)])] - ) + stmt = + if_stmt( + var("x"), + [return_stmt([number(1)])], + else: [return_stmt([number(0)])] + ) + assert %Stmt.If{else_block: %Block{}} = stmt end test "creates if-elseif-else statement" do - stmt = if_stmt( - binop(:gt, var("x"), number(0)), - [return_stmt([number(1)])], - elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], - else: [return_stmt([number(0)])] - ) + stmt = + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], + else: [return_stmt([number(0)])] + ) + assert %Stmt.If{ - elseifs: [{_, %Block{}}], - else_block: %Block{} - } = stmt + elseifs: [{_, %Block{}}], + else_block: %Block{} + } = stmt end test "creates while loop" do - stmt = while_stmt(binop(:gt, var("x"), number(0)), [ - assign([var("x")], [binop(:sub, var("x"), number(1))]) - ]) + stmt = + while_stmt(binop(:gt, var("x"), number(0)), [ + assign([var("x")], [binop(:sub, var("x"), number(1))]) + ]) + assert %Stmt.While{ - condition: %Expr.BinOp{op: :gt}, - body: %Block{} - } = stmt + condition: %Expr.BinOp{op: :gt}, + body: %Block{} + } = stmt end test "creates repeat-until loop" do - stmt = repeat_stmt( - [assign([var("x")], [binop(:sub, var("x"), number(1))])], - binop(:le, var("x"), number(0)) - ) + stmt = + 
repeat_stmt( + [assign([var("x")], [binop(:sub, var("x"), number(1))])], + binop(:le, var("x"), number(0)) + ) + assert %Stmt.Repeat{ - body: %Block{}, - condition: %Expr.BinOp{op: :le} - } = stmt + body: %Block{}, + condition: %Expr.BinOp{op: :le} + } = stmt end test "creates numeric for loop" do - stmt = for_num("i", number(1), number(10), [ - call_stmt(call(var("print"), [var("i")])) - ]) + stmt = + for_num("i", number(1), number(10), [ + call_stmt(call(var("print"), [var("i")])) + ]) + assert %Stmt.ForNum{ - var: "i", - start: %Expr.Number{value: 1}, - limit: %Expr.Number{value: 10}, - step: nil, - body: %Block{} - } = stmt + var: "i", + start: %Expr.Number{value: 1}, + limit: %Expr.Number{value: 10}, + step: nil, + body: %Block{} + } = stmt end test "creates numeric for loop with step" do - stmt = for_num("i", number(1), number(10), [ - call_stmt(call(var("print"), [var("i")])) - ], step: number(2)) + stmt = + for_num( + "i", + number(1), + number(10), + [ + call_stmt(call(var("print"), [var("i")])) + ], step: number(2)) + assert %Stmt.ForNum{step: %Expr.Number{value: 2}} = stmt end test "creates generic for loop" do - stmt = for_in( - ["k", "v"], - [call(var("pairs"), [var("t")])], - [call_stmt(call(var("print"), [var("k"), var("v")]))] - ) + stmt = + for_in( + ["k", "v"], + [call(var("pairs"), [var("t")])], + [call_stmt(call(var("print"), [var("k"), var("v")]))] + ) + assert %Stmt.ForIn{ - vars: ["k", "v"], - iterators: [%Expr.Call{}], - body: %Block{} - } = stmt + vars: ["k", "v"], + iterators: [%Expr.Call{}], + body: %Block{} + } = stmt end test "creates do block" do - stmt = do_block([ - local(["x"], [number(10)]), - call_stmt(call(var("print"), [var("x")])) - ]) + stmt = + do_block([ + local(["x"], [number(10)]), + call_stmt(call(var("print"), [var("x")])) + ]) + assert %Stmt.Do{body: %Block{stmts: [_, _]}} = stmt end end @@ -350,57 +405,59 @@ defmodule Lua.AST.BuilderTest do describe "complex structures" do test "builds nested function with closure" do # function outer(x) return function(y) return x + y end end - ast = chunk([ - func_decl("outer", ["x"], [ - return_stmt([ - function_expr(["y"], [ - return_stmt([binop(:add, var("x"), var("y"))]) + ast = + chunk([ + func_decl("outer", ["x"], [ + return_stmt([ + function_expr(["y"], [ + return_stmt([binop(:add, var("x"), var("y"))]) + ]) ]) ]) ]) - ]) assert %Chunk{ - block: %Block{ - stmts: [ - %Stmt.FuncDecl{ - name: ["outer"], - body: %Block{ - stmts: [ - %Stmt.Return{ - values: [%Expr.Function{}] - } - ] - } - } - ] - } - } = ast + block: %Block{ + stmts: [ + %Stmt.FuncDecl{ + name: ["outer"], + body: %Block{ + stmts: [ + %Stmt.Return{ + values: [%Expr.Function{}] + } + ] + } + } + ] + } + } = ast end test "builds complex if-elseif-else chain" do - ast = chunk([ - if_stmt( - binop(:gt, var("x"), number(0)), - [return_stmt([string("positive")])], - elseif: [ - {binop(:lt, var("x"), number(0)), [return_stmt([string("negative")])]}, - {binop(:eq, var("x"), number(0)), [return_stmt([string("zero")])]} - ], - else: [return_stmt([string("unknown")])] - ) - ]) + ast = + chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([string("positive")])], + elseif: [ + {binop(:lt, var("x"), number(0)), [return_stmt([string("negative")])]}, + {binop(:eq, var("x"), number(0)), [return_stmt([string("zero")])]} + ], + else: [return_stmt([string("unknown")])] + ) + ]) assert %Chunk{ - block: %Block{ - stmts: [ - %Stmt.If{ - elseifs: [{_, _}, {_, _}], - else_block: %Block{} - } - ] - } - } = ast + block: %Block{ + stmts: [ + 
%Stmt.If{ + elseifs: [{_, _}, {_, _}], + else_block: %Block{} + } + ] + } + } = ast end test "builds nested loops" do @@ -409,25 +466,26 @@ defmodule Lua.AST.BuilderTest do # print(i * j) # end # end - ast = chunk([ - for_num("i", number(1), number(10), [ - for_num("j", number(1), number(10), [ - call_stmt(call(var("print"), [binop(:mul, var("i"), var("j"))])) + ast = + chunk([ + for_num("i", number(1), number(10), [ + for_num("j", number(1), number(10), [ + call_stmt(call(var("print"), [binop(:mul, var("i"), var("j"))])) + ]) ]) ]) - ]) assert %Chunk{ - block: %Block{ - stmts: [ - %Stmt.ForNum{ - body: %Block{ - stmts: [%Stmt.ForNum{}] - } - } - ] - } - } = ast + block: %Block{ + stmts: [ + %Stmt.ForNum{ + body: %Block{ + stmts: [%Stmt.ForNum{}] + } + } + ] + } + } = ast end test "builds table with complex expressions" do @@ -437,36 +495,39 @@ defmodule Lua.AST.BuilderTest do # [key] = value, # nested = {a = 1, b = 2} # } - tbl = table([ - {:record, string("x"), binop(:add, number(1), number(2))}, - {:record, string("y"), call(var("func"), [])}, - {:record, var("key"), var("value")}, - {:record, string("nested"), table([ - {:record, string("a"), number(1)}, - {:record, string("b"), number(2)} - ])} - ]) + tbl = + table([ + {:record, string("x"), binop(:add, number(1), number(2))}, + {:record, string("y"), call(var("func"), [])}, + {:record, var("key"), var("value")}, + {:record, string("nested"), + table([ + {:record, string("a"), number(1)}, + {:record, string("b"), number(2)} + ])} + ]) assert %Expr.Table{ - fields: [ - {:record, %Expr.String{value: "x"}, %Expr.BinOp{}}, - {:record, %Expr.String{value: "y"}, %Expr.Call{}}, - {:record, %Expr.Var{}, %Expr.Var{}}, - {:record, %Expr.String{value: "nested"}, %Expr.Table{}} - ] - } = tbl + fields: [ + {:record, %Expr.String{value: "x"}, %Expr.BinOp{}}, + {:record, %Expr.String{value: "y"}, %Expr.Call{}}, + {:record, %Expr.Var{}, %Expr.Var{}}, + {:record, %Expr.String{value: "nested"}, %Expr.Table{}} + ] + } = tbl end end describe "integration with parser" do test "builder output can be printed and reparsed" do # Build an AST using builder - ast = chunk([ - local(["x"], [number(10)]), - local(["y"], [number(20)]), - assign([var("z")], [binop(:add, var("x"), var("y"))]), - call_stmt(call(var("print"), [var("z")])) - ]) + ast = + chunk([ + local(["x"], [number(10)]), + local(["y"], [number(20)]), + assign([var("z")], [binop(:add, var("x"), var("y"))]), + call_stmt(call(var("print"), [var("z")])) + ]) # Print it code = Lua.AST.PrettyPrinter.print(ast) diff --git a/test/lua/ast/meta_test.exs b/test/lua/ast/meta_test.exs index 2ce9dd5..526a5c0 100644 --- a/test/lua/ast/meta_test.exs +++ b/test/lua/ast/meta_test.exs @@ -77,8 +77,11 @@ defmodule Lua.AST.MetaTest do describe "merge/2" do test "merges two metas taking earliest start" do - meta1 = Meta.new(%{line: 1, column: 5, byte_offset: 10}, %{line: 1, column: 10, byte_offset: 20}) - meta2 = Meta.new(%{line: 1, column: 1, byte_offset: 5}, %{line: 1, column: 8, byte_offset: 15}) + meta1 = + Meta.new(%{line: 1, column: 5, byte_offset: 10}, %{line: 1, column: 10, byte_offset: 20}) + + meta2 = + Meta.new(%{line: 1, column: 1, byte_offset: 5}, %{line: 1, column: 8, byte_offset: 15}) merged = Meta.merge(meta1, meta2) @@ -86,15 +89,18 @@ defmodule Lua.AST.MetaTest do end test "merges two metas taking latest end" do - meta1 = Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 10, byte_offset: 9}) - meta2 = Meta.new(%{line: 1, column: 5, byte_offset: 4}, %{line: 1, column: 20, byte_offset: 
19}) + meta1 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 10, byte_offset: 9}) + + meta2 = + Meta.new(%{line: 1, column: 5, byte_offset: 4}, %{line: 1, column: 20, byte_offset: 19}) merged = Meta.merge(meta1, meta2) assert merged.end == %{line: 1, column: 20, byte_offset: 19} end - test "handles nil positions" do + test "handles nil start positions" do meta1 = Meta.new(nil, %{line: 1, column: 10, byte_offset: 9}) meta2 = Meta.new(%{line: 1, column: 1, byte_offset: 0}, nil) @@ -103,32 +109,96 @@ defmodule Lua.AST.MetaTest do assert merged.start == %{line: 1, column: 1, byte_offset: 0} assert merged.end == %{line: 1, column: 10, byte_offset: 9} end + + test "handles nil end positions" do + meta1 = Meta.new(%{line: 1, column: 1, byte_offset: 0}, nil) + meta2 = Meta.new(nil, %{line: 1, column: 10, byte_offset: 9}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "handles both positions nil on first meta" do + meta1 = Meta.new(nil, nil) + + meta2 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 10, byte_offset: 9}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "handles both positions nil on second meta" do + meta1 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 10, byte_offset: 9}) + + meta2 = Meta.new(nil, nil) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "chooses earlier position when first has earlier byte offset" do + meta1 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 5, byte_offset: 4}) + + meta2 = + Meta.new(%{line: 1, column: 6, byte_offset: 5}, %{line: 1, column: 10, byte_offset: 9}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "chooses later position when second has later byte offset" do + meta1 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 5, byte_offset: 4}) + + meta2 = + Meta.new(%{line: 1, column: 2, byte_offset: 1}, %{line: 1, column: 10, byte_offset: 9}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end end describe "position tracking" do test "stores line numbers" do - meta = Meta.new(%{line: 5, column: 10, byte_offset: 50}, %{line: 5, column: 20, byte_offset: 60}) + meta = + Meta.new(%{line: 5, column: 10, byte_offset: 50}, %{line: 5, column: 20, byte_offset: 60}) assert meta.start.line == 5 assert meta.end.line == 5 end test "stores column numbers" do - meta = Meta.new(%{line: 1, column: 5, byte_offset: 4}, %{line: 1, column: 15, byte_offset: 14}) + meta = + Meta.new(%{line: 1, column: 5, byte_offset: 4}, %{line: 1, column: 15, byte_offset: 14}) assert meta.start.column == 5 assert meta.end.column == 15 end test "stores byte offsets" do - meta = Meta.new(%{line: 1, column: 1, byte_offset: 100}, %{line: 1, column: 10, byte_offset: 200}) + meta = + Meta.new(%{line: 1, column: 1, byte_offset: 100}, %{line: 1, column: 10, byte_offset: 200}) assert meta.start.byte_offset == 100 assert meta.end.byte_offset == 200 end 
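+    # positions are plain maps with :line, :column and :byte_offset keys, so a
+    # span is free to start and end on different lines, e.g. (illustrative values):
+    #
+    #   Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 3, column: 4, byte_offset: 42})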
test "handles multiline spans" do - meta = Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 10, column: 5, byte_offset: 150}) + meta = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 10, column: 5, byte_offset: 150}) assert meta.start.line == 1 assert meta.end.line == 10 @@ -137,12 +207,13 @@ defmodule Lua.AST.MetaTest do describe "metadata storage" do test "stores arbitrary data" do - meta = Meta.new(nil, nil, %{ - node_type: :function, - name: "test", - params: ["a", "b"], - is_async: false - }) + meta = + Meta.new(nil, nil, %{ + node_type: :function, + name: "test", + params: ["a", "b"], + is_async: false + }) assert meta.metadata.node_type == :function assert meta.metadata.name == "test" @@ -151,12 +222,13 @@ defmodule Lua.AST.MetaTest do end test "stores nested data structures" do - meta = Meta.new(nil, nil, %{ - scope: %{ - variables: ["x", "y"], - functions: ["f", "g"] - } - }) + meta = + Meta.new(nil, nil, %{ + scope: %{ + variables: ["x", "y"], + functions: ["f", "g"] + } + }) assert meta.metadata.scope.variables == ["x", "y"] assert meta.metadata.scope.functions == ["f", "g"] diff --git a/test/lua/ast/pretty_printer_test.exs b/test/lua/ast/pretty_printer_test.exs index c812f81..644fb9d 100644 --- a/test/lua/ast/pretty_printer_test.exs +++ b/test/lua/ast/pretty_printer_test.exs @@ -51,9 +51,13 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints chained property access" do - ast = chunk([return_stmt([ - property(property(var("a"), "b"), "c") - ])]) + ast = + chunk([ + return_stmt([ + property(property(var("a"), "b"), "c") + ]) + ]) + assert PrettyPrinter.print(ast) == "return a.b.c\n" end end @@ -77,29 +81,45 @@ defmodule Lua.AST.PrettyPrinterTest do test "handles operator precedence with parentheses" do # 2 + 3 * 4 should print as is (multiplication has higher precedence) - ast = chunk([return_stmt([ - binop(:add, number(2), binop(:mul, number(3), number(4))) - ])]) + ast = + chunk([ + return_stmt([ + binop(:add, number(2), binop(:mul, number(3), number(4))) + ]) + ]) + assert PrettyPrinter.print(ast) == "return 2 + 3 * 4\n" # (2 + 3) * 4 should have parentheses - ast = chunk([return_stmt([ - binop(:mul, binop(:add, number(2), number(3)), number(4)) - ])]) + ast = + chunk([ + return_stmt([ + binop(:mul, binop(:add, number(2), number(3)), number(4)) + ]) + ]) + assert PrettyPrinter.print(ast) == "return (2 + 3) * 4\n" end test "handles right-associative operators" do # 2 ^ 3 ^ 4 should print as 2 ^ 3 ^ 4 (right-associative) - ast = chunk([return_stmt([ - binop(:pow, number(2), binop(:pow, number(3), number(4))) - ])]) + ast = + chunk([ + return_stmt([ + binop(:pow, number(2), binop(:pow, number(3), number(4))) + ]) + ]) + assert PrettyPrinter.print(ast) == "return 2 ^ 3 ^ 4\n" # (2 ^ 3) ^ 4 should have parentheses - ast = chunk([return_stmt([ - binop(:pow, binop(:pow, number(2), number(3)), number(4)) - ])]) + ast = + chunk([ + return_stmt([ + binop(:pow, binop(:pow, number(2), number(3)), number(4)) + ]) + ]) + assert PrettyPrinter.print(ast) == "return (2 ^ 3) ^ 4\n" end @@ -119,33 +139,45 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints array-style table" do - ast = chunk([return_stmt([ - table([ - {:list, number(1)}, - {:list, number(2)}, - {:list, number(3)} + ast = + chunk([ + return_stmt([ + table([ + {:list, number(1)}, + {:list, number(2)}, + {:list, number(3)} + ]) + ]) ]) - ])]) + assert PrettyPrinter.print(ast) == "return {1, 2, 3}\n" end test "prints record-style table" do - ast = chunk([return_stmt([ - table([ - {:record, string("x"), 
number(10)}, - {:record, string("y"), number(20)} + ast = + chunk([ + return_stmt([ + table([ + {:record, string("x"), number(10)}, + {:record, string("y"), number(20)} + ]) + ]) ]) - ])]) + assert PrettyPrinter.print(ast) == "return {x = 10, y = 20}\n" end test "prints mixed table fields" do - ast = chunk([return_stmt([ - table([ - {:list, number(1)}, - {:record, string("x"), number(10)} + ast = + chunk([ + return_stmt([ + table([ + {:list, number(1)}, + {:record, string("x"), number(10)} + ]) + ]) ]) - ])]) + assert PrettyPrinter.print(ast) == "return {1, x = 10}\n" end end @@ -169,9 +201,13 @@ defmodule Lua.AST.PrettyPrinterTest do describe "function expressions" do test "prints simple function" do - ast = chunk([return_stmt([ - function_expr(["x"], [return_stmt([var("x")])]) - ])]) + ast = + chunk([ + return_stmt([ + function_expr(["x"], [return_stmt([var("x")])]) + ]) + ]) + result = PrettyPrinter.print(ast) assert result =~ "function(x)" assert result =~ "return x" @@ -179,18 +215,26 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints function with multiple parameters" do - ast = chunk([return_stmt([ - function_expr(["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) - ])]) + ast = + chunk([ + return_stmt([ + function_expr(["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + ]) + ]) + result = PrettyPrinter.print(ast) assert result =~ "function(a, b)" assert result =~ "return a + b" end test "prints function with vararg" do - ast = chunk([return_stmt([ - function_expr([], [return_stmt([vararg()])], vararg: true) - ])]) + ast = + chunk([ + return_stmt([ + function_expr([], [return_stmt([vararg()])], vararg: true) + ]) + ]) + result = PrettyPrinter.print(ast) assert result =~ "function(...)" end @@ -218,7 +262,9 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints local function" do - ast = chunk([local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + ast = + chunk([local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + result = PrettyPrinter.print(ast) assert result =~ "local function add(a, b)" assert result =~ "return a + b" @@ -226,7 +272,9 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints function declaration" do - ast = chunk([func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + ast = + chunk([func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + result = PrettyPrinter.print(ast) assert result =~ "function add(a, b)" assert result =~ "return a + b" @@ -249,9 +297,11 @@ defmodule Lua.AST.PrettyPrinterTest do describe "control flow" do test "prints if statement" do - ast = chunk([ - if_stmt(var("x"), [return_stmt([number(1)])]) - ]) + ast = + chunk([ + if_stmt(var("x"), [return_stmt([number(1)])]) + ]) + result = PrettyPrinter.print(ast) assert result =~ "if x then" assert result =~ "return 1" @@ -259,13 +309,15 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints if-else statement" do - ast = chunk([ - if_stmt( - var("x"), - [return_stmt([number(1)])], - else: [return_stmt([number(0)])] - ) - ]) + ast = + chunk([ + if_stmt( + var("x"), + [return_stmt([number(1)])], + else: [return_stmt([number(0)])] + ) + ]) + result = PrettyPrinter.print(ast) assert result =~ "if x then" assert result =~ "else" @@ -273,14 +325,16 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints if-elseif-else statement" do - ast = chunk([ - if_stmt( - binop(:gt, var("x"), number(0)), - [return_stmt([number(1)])], - elseif: [{binop(:lt, 
var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], - else: [return_stmt([number(0)])] - ) - ]) + ast = + chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], + else: [return_stmt([number(0)])] + ) + ]) + result = PrettyPrinter.print(ast) assert result =~ "if x > 0 then" assert result =~ "elseif x < 0 then" @@ -289,11 +343,13 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints while loop" do - ast = chunk([ - while_stmt(binop(:gt, var("x"), number(0)), [ - assign([var("x")], [binop(:sub, var("x"), number(1))]) + ast = + chunk([ + while_stmt(binop(:gt, var("x"), number(0)), [ + assign([var("x")], [binop(:sub, var("x"), number(1))]) + ]) ]) - ]) + result = PrettyPrinter.print(ast) assert result =~ "while x > 0 do" assert result =~ "x = x - 1" @@ -301,12 +357,14 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints repeat-until loop" do - ast = chunk([ - repeat_stmt( - [assign([var("x")], [binop(:sub, var("x"), number(1))])], - binop(:le, var("x"), number(0)) - ) - ]) + ast = + chunk([ + repeat_stmt( + [assign([var("x")], [binop(:sub, var("x"), number(1))])], + binop(:le, var("x"), number(0)) + ) + ]) + result = PrettyPrinter.print(ast) assert result =~ "repeat" assert result =~ "x = x - 1" @@ -314,11 +372,13 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints numeric for loop" do - ast = chunk([ - for_num("i", number(1), number(10), [ - call_stmt(call(var("print"), [var("i")])) + ast = + chunk([ + for_num("i", number(1), number(10), [ + call_stmt(call(var("print"), [var("i")])) + ]) ]) - ]) + result = PrettyPrinter.print(ast) assert result =~ "for i = 1, 10 do" assert result =~ "print(i)" @@ -326,23 +386,33 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints numeric for loop with step" do - ast = chunk([ - for_num("i", number(1), number(10), [ - call_stmt(call(var("print"), [var("i")])) - ], step: number(2)) - ]) + ast = + chunk([ + for_num( + "i", + number(1), + number(10), + [ + call_stmt(call(var("print"), [var("i")])) + ], + step: number(2) + ) + ]) + result = PrettyPrinter.print(ast) assert result =~ "for i = 1, 10, 2 do" end test "prints generic for loop" do - ast = chunk([ - for_in( - ["k", "v"], - [call(var("pairs"), [var("t")])], - [call_stmt(call(var("print"), [var("k"), var("v")]))] - ) - ]) + ast = + chunk([ + for_in( + ["k", "v"], + [call(var("pairs"), [var("t")])], + [call_stmt(call(var("print"), [var("k"), var("v")]))] + ) + ]) + result = PrettyPrinter.print(ast) assert result =~ "for k, v in pairs(t) do" assert result =~ "print(k, v)" @@ -350,12 +420,14 @@ defmodule Lua.AST.PrettyPrinterTest do end test "prints do block" do - ast = chunk([ - do_block([ - local(["x"], [number(10)]), - call_stmt(call(var("print"), [var("x")])) + ast = + chunk([ + do_block([ + local(["x"], [number(10)]), + call_stmt(call(var("print"), [var("x")])) + ]) ]) - ]) + result = PrettyPrinter.print(ast) assert result =~ "do" assert result =~ "local x = 10" @@ -366,13 +438,15 @@ defmodule Lua.AST.PrettyPrinterTest do describe "indentation" do test "indents nested blocks" do - ast = chunk([ - if_stmt(var("x"), [ - if_stmt(var("y"), [ - return_stmt([number(1)]) + ast = + chunk([ + if_stmt(var("x"), [ + if_stmt(var("y"), [ + return_stmt([number(1)]) + ]) ]) ]) - ]) + result = PrettyPrinter.print(ast) # Check that nested blocks are indented lines = String.split(result, "\n", trim: true) @@ -380,11 +454,13 @@ defmodule Lua.AST.PrettyPrinterTest do end test 
"respects custom indent size" do - ast = chunk([ - if_stmt(var("x"), [ - return_stmt([number(1)]) + ast = + chunk([ + if_stmt(var("x"), [ + return_stmt([number(1)]) + ]) ]) - ]) + result = PrettyPrinter.print(ast, indent: 4) assert result =~ " return 1" end @@ -422,4 +498,647 @@ defmodule Lua.AST.PrettyPrinterTest do assert ast.block.stmts |> length() == ast2.block.stmts |> length() end end + + describe "string escaping" do + test "escapes backslash" do + ast = chunk([return_stmt([string("path\\to\\file")])]) + result = PrettyPrinter.print(ast) + assert result == "return \"path\\\\to\\\\file\"\n" + end + + test "escapes double quotes" do + ast = chunk([return_stmt([string("say \"hello\"")])]) + result = PrettyPrinter.print(ast) + assert result == "return \"say \\\"hello\\\"\"\n" + end + + test "escapes tab character" do + ast = chunk([return_stmt([string("hello\tworld")])]) + result = PrettyPrinter.print(ast) + assert result == "return \"hello\\tworld\"\n" + end + + test "escapes all special characters together" do + ast = chunk([return_stmt([string("line1\n\"quote\"\ttab\\back")])]) + result = PrettyPrinter.print(ast) + assert result == "return \"line1\\n\\\"quote\\\"\\ttab\\\\back\"\n" + end + end + + describe "all binary operators" do + test "prints floor division" do + ast = chunk([return_stmt([binop(:floor_div, number(10), number(3))])]) + assert PrettyPrinter.print(ast) == "return 10 // 3\n" + end + + test "prints modulo" do + ast = chunk([return_stmt([binop(:mod, number(10), number(3))])]) + assert PrettyPrinter.print(ast) == "return 10 % 3\n" + end + + test "prints concatenation" do + ast = chunk([return_stmt([binop(:concat, string("hello"), string("world"))])]) + assert PrettyPrinter.print(ast) == "return \"hello\" .. \"world\"\n" + end + + test "prints equality" do + ast = chunk([return_stmt([binop(:eq, var("x"), var("y"))])]) + assert PrettyPrinter.print(ast) == "return x == y\n" + end + + test "prints inequality" do + ast = chunk([return_stmt([binop(:ne, var("x"), var("y"))])]) + assert PrettyPrinter.print(ast) == "return x ~= y\n" + end + + test "prints less than or equal" do + ast = chunk([return_stmt([binop(:le, var("x"), var("y"))])]) + assert PrettyPrinter.print(ast) == "return x <= y\n" + end + + test "prints greater than or equal" do + ast = chunk([return_stmt([binop(:ge, var("x"), var("y"))])]) + assert PrettyPrinter.print(ast) == "return x >= y\n" + end + + test "prints logical and" do + ast = chunk([return_stmt([binop(:and, var("x"), var("y"))])]) + assert PrettyPrinter.print(ast) == "return x and y\n" + end + + test "prints logical or" do + ast = chunk([return_stmt([binop(:or, var("x"), var("y"))])]) + assert PrettyPrinter.print(ast) == "return x or y\n" + end + end + + describe "operator associativity" do + test "handles concat right-associativity" do + # "a" .. "b" .. "c" should print as is (right-associative) + ast = + chunk([ + return_stmt([ + binop(:concat, string("a"), binop(:concat, string("b"), string("c"))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return \"a\" .. \"b\" .. \"c\"\n" + + # ("a" .. "b") .. "c" should have parentheses + ast = + chunk([ + return_stmt([ + binop(:concat, binop(:concat, string("a"), string("b")), string("c")) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return (\"a\" .. \"b\") .. 
\"c\"\n" + end + + test "handles subtraction left-associativity" do + # (10 - 5) - 2 should print as is (left-associative) + ast = + chunk([ + return_stmt([ + binop(:sub, binop(:sub, number(10), number(5)), number(2)) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return 10 - 5 - 2\n" + + # 10 - (5 - 2) should have parentheses + ast = + chunk([ + return_stmt([ + binop(:sub, number(10), binop(:sub, number(5), number(2))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return 10 - (5 - 2)\n" + end + end + + describe "table key formatting" do + test "uses bracket notation for non-identifier strings" do + ast = + chunk([ + return_stmt([ + table([ + {:record, string("not-valid"), number(1)}, + {:record, string("123"), number(2)}, + {:record, string("with space"), number(3)} + ]) + ]) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "[\"not-valid\"] = 1" + assert result =~ "[\"123\"] = 2" + assert result =~ "[\"with space\"] = 3" + end + + test "uses bracket notation for Lua keywords" do + ast = + chunk([ + return_stmt([ + table([ + {:record, string("end"), number(1)}, + {:record, string("while"), number(2)}, + {:record, string("function"), number(3)} + ]) + ]) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "[\"end\"] = 1" + assert result =~ "[\"while\"] = 2" + assert result =~ "[\"function\"] = 3" + end + + test "uses bracket notation for non-string keys" do + ast = + chunk([ + return_stmt([ + table([ + {:record, number(1), string("first")}, + {:record, var("x"), string("variable")} + ]) + ]) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "[1] = \"first\"" + assert result =~ "[x] = \"variable\"" + end + end + + describe "dotted function names" do + test "prints function with dotted name" do + ast = + chunk([ + func_decl(["math", "add"], ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "function math.add(a, b)" + end + + test "prints function with deeply nested name" do + ast = + chunk([ + func_decl(["a", "b", "c"], ["x"], [return_stmt([var("x")])]) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "function a.b.c(x)" + end + end + + describe "local statement variations" do + test "prints local with nil values" do + ast = chunk([local(["x"], nil)]) + assert PrettyPrinter.print(ast) == "local x\n" + end + + test "prints multiple local variables" do + ast = chunk([local(["x", "y", "z"], [])]) + assert PrettyPrinter.print(ast) == "local x, y, z\n" + end + + test "prints multiple local variables with values" do + ast = chunk([local(["x", "y"], [number(1), number(2)])]) + assert PrettyPrinter.print(ast) == "local x, y = 1, 2\n" + end + end + + describe "goto and label statements" do + test "prints goto statement" do + ast = chunk([goto_stmt("skip")]) + assert PrettyPrinter.print(ast) == "goto skip\n" + end + + test "prints label statement" do + ast = chunk([label("skip")]) + assert PrettyPrinter.print(ast) == "::skip::\n" + end + end + + describe "function calls with no arguments" do + test "prints function call with no arguments" do + ast = chunk([call_stmt(call(var("print"), []))]) + assert PrettyPrinter.print(ast) == "print()\n" + end + + test "prints method call with no arguments" do + ast = chunk([call_stmt(method_call(var("obj"), "method", []))]) + assert PrettyPrinter.print(ast) == "obj:method()\n" + end + end + + describe "if statement variations" do + test "prints if statement without elseif" do + ast = + chunk([ + if_stmt( + var("x"), + 
[return_stmt([number(1)])], + elseif: [] + ) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "if x then" + assert result =~ "return 1" + assert result =~ "end" + end + + test "prints if statement with multiple elseifs" do + ast = + chunk([ + if_stmt( + binop(:eq, var("x"), number(1)), + [return_stmt([string("one")])], + elseif: [ + {binop(:eq, var("x"), number(2)), [return_stmt([string("two")])]}, + {binop(:eq, var("x"), number(3)), [return_stmt([string("three")])]} + ], + else: [return_stmt([string("other")])] + ) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "if x == 1 then" + assert result =~ "elseif x == 2 then" + assert result =~ "elseif x == 3 then" + assert result =~ "else" + assert result =~ "end" + end + end + + describe "local function with vararg" do + test "prints local function with vararg parameter" do + ast = + chunk([ + local_func("variadic", [], [return_stmt([vararg()])], vararg: true) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "local function variadic(...)" + assert result =~ "return ..." + end + + test "prints local function with mixed parameters and vararg" do + ast = + chunk([ + local_func("mixed", ["a", "b"], [return_stmt([vararg()])], vararg: true) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "local function mixed(a, b, ...)" + end + end + + describe "function declaration with vararg" do + test "prints function declaration with vararg parameter" do + ast = + chunk([ + func_decl("variadic", [], [return_stmt([vararg()])], vararg: true) + ]) + + result = PrettyPrinter.print(ast) + assert result =~ "function variadic(...)" + assert result =~ "return ..." + end + end + + describe "precedence edge cases" do + test "handles division and multiplication precedence" do + # 10 / 2 * 3 (same precedence, left-to-right) + ast = + chunk([ + return_stmt([ + binop(:mul, binop(:div, number(10), number(2)), number(3)) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return 10 / 2 * 3\n" + end + + test "handles comparison operators with arithmetic" do + # 2 + 3 < 10 (arithmetic has higher precedence) + ast = + chunk([ + return_stmt([ + binop(:lt, binop(:add, number(2), number(3)), number(10)) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return 2 + 3 < 10\n" + end + + test "handles logical operators with comparisons" do + # x < 10 and y > 5 (comparison has higher precedence) + ast = + chunk([ + return_stmt([ + binop(:and, binop(:lt, var("x"), number(10)), binop(:gt, var("y"), number(5))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return x < 10 and y > 5\n" + end + + test "handles or with and (and has higher precedence)" do + # a or b and c should print as a or (b and c) + ast = + chunk([ + return_stmt([ + binop(:or, var("a"), binop(:and, var("b"), var("c"))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return a or b and c\n" + end + end + + describe "all comparison operators in precedence" do + test "handles all comparison operators" do + # Test that all comparison operators work together + ast = chunk([return_stmt([binop(:lt, var("a"), var("b"))])]) + assert PrettyPrinter.print(ast) == "return a < b\n" + + ast = chunk([return_stmt([binop(:gt, var("a"), var("b"))])]) + assert PrettyPrinter.print(ast) == "return a > b\n" + end + end + + describe "nested structures" do + test "handles deeply nested expressions" do + # ((a + b) * c) / d + ast = + chunk([ + return_stmt([ + binop( + :div, + binop(:mul, binop(:add, var("a"), var("b")), var("c")), + var("d") + ) + ]) + ]) + + result = PrettyPrinter.print(ast) + 
assert result == "return (a + b) * c / d\n" + end + + test "handles nested table access" do + # a[b[c]] + ast = + chunk([ + return_stmt([ + index(var("a"), index(var("b"), var("c"))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return a[b[c]]\n" + end + + test "handles chained method calls" do + # obj:method1():method2() + ast = + chunk([ + return_stmt([ + method_call( + method_call(var("obj"), "method1", []), + "method2", + [] + ) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return obj:method1():method2()\n" + end + end + + describe "function name variations" do + test "prints simple function name" do + ast = chunk([func_decl("simple", [], [return_stmt([])])]) + result = PrettyPrinter.print(ast) + assert result =~ "function simple()" + end + end + + describe "number formatting edge cases" do + test "formats integer as integer" do + ast = chunk([return_stmt([number(42)])]) + assert PrettyPrinter.print(ast) == "return 42\n" + end + + test "formats float that equals integer with .0" do + ast = chunk([return_stmt([number(5.0)])]) + assert PrettyPrinter.print(ast) == "return 5.0\n" + end + + test "formats regular float normally" do + ast = chunk([return_stmt([number(3.14159)])]) + assert PrettyPrinter.print(ast) == "return 3.14159\n" + end + + test "formats negative numbers" do + ast = chunk([return_stmt([number(-42)])]) + assert PrettyPrinter.print(ast) == "return -42\n" + end + end + + describe "precedence with various operators" do + test "comparison operators with power" do + # x < y ^ 2 (power has higher precedence) + ast = + chunk([ + return_stmt([ + binop(:lt, var("x"), binop(:pow, var("y"), number(2))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return x < y ^ 2\n" + end + + test "equality with arithmetic" do + # x == y + 1 (arithmetic has higher precedence) + ast = + chunk([ + return_stmt([ + binop(:eq, var("x"), binop(:add, var("y"), number(1))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return x == y + 1\n" + end + + test "inequality with arithmetic" do + # x ~= y * 2 (arithmetic has higher precedence) + ast = + chunk([ + return_stmt([ + binop(:ne, var("x"), binop(:mul, var("y"), number(2))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return x ~= y * 2\n" + end + + test "less than or equal with addition" do + # x <= y + 5 (arithmetic has higher precedence) + ast = + chunk([ + return_stmt([ + binop(:le, var("x"), binop(:add, var("y"), number(5))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return x <= y + 5\n" + end + + test "greater than or equal with subtraction" do + # x >= y - 5 (arithmetic has higher precedence) + ast = + chunk([ + return_stmt([ + binop(:ge, var("x"), binop(:sub, var("y"), number(5))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return x >= y - 5\n" + end + + test "concat with comparison" do + # (x < y) .. "test" (comparison has lower precedence than concat) + ast = + chunk([ + return_stmt([ + binop(:concat, binop(:lt, var("x"), var("y")), string("test")) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return (x < y) .. 
\"test\"\n" + end + + test "floor division with addition" do + # (x + y) // 2 (addition has lower precedence) + ast = + chunk([ + return_stmt([ + binop(:floor_div, binop(:add, var("x"), var("y")), number(2)) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return (x + y) // 2\n" + end + + test "modulo with addition" do + # (x + y) % 10 (addition has lower precedence) + ast = + chunk([ + return_stmt([ + binop(:mod, binop(:add, var("x"), var("y")), number(10)) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return (x + y) % 10\n" + end + + test "power with unary operator" do + # 2 ^ (-x) (unary needs parens with power parent) + ast = + chunk([ + return_stmt([ + binop(:pow, number(2), unop(:neg, var("x"))) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return 2 ^ (-x)\n" + end + + test "unary not with non-power operator" do + # not x and y (unary has higher precedence than and) + ast = + chunk([ + return_stmt([ + binop(:and, unop(:not, var("x")), var("y")) + ]) + ]) + + assert PrettyPrinter.print(ast) == "return not x and y\n" + end + end + + describe "binary name for function declaration" do + test "prints function with string name" do + # When passing a simple string name (not a list) + alias Lua.AST.Stmt + + ast = %Lua.AST.Chunk{ + block: %Lua.AST.Block{ + stmts: [ + %Stmt.FuncDecl{ + name: "simple", + params: [], + body: %Lua.AST.Block{stmts: [return_stmt([])]}, + is_method: false, + meta: nil + } + ] + } + } + + result = PrettyPrinter.print(ast) + assert result =~ "function simple()" + end + end + + describe "unknown operators (defensive default cases)" do + test "handles unknown binary operator" do + # Create a BinOp with an invalid operator to test the default case + alias Lua.AST.Expr + + ast = %Lua.AST.Chunk{ + block: %Lua.AST.Block{ + stmts: [ + return_stmt([ + %Expr.BinOp{ + op: :unknown_op, + left: number(1), + right: number(2), + meta: nil + } + ]) + ] + } + } + + result = PrettyPrinter.print(ast) + assert result =~ "" + end + + test "handles unknown unary operator" do + # Create a UnOp with an invalid operator to test the default case + alias Lua.AST.Expr + + ast = %Lua.AST.Chunk{ + block: %Lua.AST.Block{ + stmts: [ + return_stmt([ + %Expr.UnOp{ + op: :unknown_unop, + operand: number(42), + meta: nil + } + ]) + ] + } + } + + result = PrettyPrinter.print(ast) + assert result =~ "" + end + end end diff --git a/test/lua/ast/walker_test.exs b/test/lua/ast/walker_test.exs index e8b5429..292b4e8 100644 --- a/test/lua/ast/walker_test.exs +++ b/test/lua/ast/walker_test.exs @@ -7,11 +7,11 @@ defmodule Lua.AST.WalkerTest do describe "walk/2" do test "visits all nodes in pre-order" do # Build: local x = 2 + 3 - ast = chunk([ - local(["x"], [binop(:add, number(2), number(3))]) - ]) + ast = + chunk([ + local(["x"], [binop(:add, number(2), number(3))]) + ]) - visited = [] ref = :erlang.make_ref() Walker.walk(ast, fn node -> @@ -28,16 +28,20 @@ defmodule Lua.AST.WalkerTest do test "visits all nodes in post-order" do # Build: local x = 2 + 3 - ast = chunk([ - local(["x"], [binop(:add, number(2), number(3))]) - ]) + ast = + chunk([ + local(["x"], [binop(:add, number(2), number(3))]) + ]) - visited = [] ref = :erlang.make_ref() - Walker.walk(ast, fn node -> - send(self(), {ref, node}) - end, order: :post) + Walker.walk( + ast, + fn node -> + send(self(), {ref, node}) + end, + order: :post + ) visited = collect_messages(ref, []) @@ -49,14 +53,15 @@ defmodule Lua.AST.WalkerTest do test "walks through if statement with all branches" do # if x > 0 then return 1 elseif x < 0 then return -1 else 
return 0 end - ast = chunk([ - if_stmt( - binop(:gt, var("x"), number(0)), - [return_stmt([number(1)])], - elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], - else: [return_stmt([number(0)])] - ) - ]) + ast = + chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], + else: [return_stmt([number(0)])] + ) + ]) count = count_nodes(ast) # Chunk + Block + If + 3 conditions + 3 blocks + 3 return stmts + 3 values = many nodes @@ -65,100 +70,117 @@ defmodule Lua.AST.WalkerTest do test "walks through function expressions" do # local f = function(a, b) return a + b end - ast = chunk([ - local(["f"], [function_expr(["a", "b"], [ - return_stmt([binop(:add, var("a"), var("b"))]) - ])]) - ]) + ast = + chunk([ + local(["f"], [ + function_expr(["a", "b"], [ + return_stmt([binop(:add, var("a"), var("b"))]) + ]) + ]) + ]) # Count variable references - var_count = Walker.reduce(ast, 0, fn - %Expr.Var{}, acc -> acc + 1 - _, acc -> acc - end) + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) - assert var_count == 2 # a and b + # a and b + assert var_count == 2 end end describe "map/2" do test "transforms number literals" do # local x = 2 + 3 - ast = chunk([ - local(["x"], [binop(:add, number(2), number(3))]) - ]) + ast = + chunk([ + local(["x"], [binop(:add, number(2), number(3))]) + ]) # Double all numbers - transformed = Walker.map(ast, fn - %Expr.Number{value: n} = node -> %{node | value: n * 2} - node -> node - end) + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) # Extract the numbers - numbers = Walker.reduce(transformed, [], fn - %Expr.Number{value: n}, acc -> [n | acc] - _, acc -> acc - end) + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) - assert Enum.sort(numbers) == [4, 6] # 2*2=4, 3*2=6 + # 2*2=4, 3*2=6 + assert Enum.sort(numbers) == [4, 6] end test "transforms variable names" do # x = y + z - ast = chunk([ - assign([var("x")], [binop(:add, var("y"), var("z"))]) - ]) + ast = + chunk([ + assign([var("x")], [binop(:add, var("y"), var("z"))]) + ]) # Add prefix to all variables - transformed = Walker.map(ast, fn - %Expr.Var{name: name} = node -> %{node | name: "local_" <> name} - node -> node - end) + transformed = + Walker.map(ast, fn + %Expr.Var{name: name} = node -> %{node | name: "local_" <> name} + node -> node + end) # Collect variable names - names = Walker.reduce(transformed, [], fn - %Expr.Var{name: name}, acc -> [name | acc] - _, acc -> acc - end) + names = + Walker.reduce(transformed, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) assert Enum.sort(names) == ["local_x", "local_y", "local_z"] end test "preserves structure while transforming" do # if true then print(1) end - ast = chunk([ - if_stmt(bool(true), [ - call_stmt(call(var("print"), [number(1)])) + ast = + chunk([ + if_stmt(bool(true), [ + call_stmt(call(var("print"), [number(1)])) + ]) ]) - ]) # Transform should preserve structure - transformed = Walker.map(ast, fn - %Expr.Number{value: n} = node -> %{node | value: n + 1} - node -> node - end) + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 1} + node -> node + end) # Extract the if statement [%Stmt.If{condition: %Expr.Bool{value: true}}] = transformed.block.stmts # Number should be 
transformed - numbers = Walker.reduce(transformed, [], fn - %Expr.Number{value: n}, acc -> [n | acc] - _, acc -> acc - end) + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) - assert numbers == [2] # 1 + 1 = 2 + # 1 + 1 = 2 + assert numbers == [2] end end describe "reduce/3" do test "counts all nodes" do # local x = 1; local y = 2; return x + y - ast = chunk([ - local(["x"], [number(1)]), - local(["y"], [number(2)]), - return_stmt([binop(:add, var("x"), var("y"))]) - ]) + ast = + chunk([ + local(["x"], [number(1)]), + local(["y"], [number(2)]), + return_stmt([binop(:add, var("x"), var("y"))]) + ]) count = Walker.reduce(ast, 0, fn _, acc -> acc + 1 end) @@ -168,62 +190,72 @@ defmodule Lua.AST.WalkerTest do test "collects specific node types" do # local x = 1; y = 2; print(x, y) - ast = chunk([ - local(["x"], [number(1)]), - assign([var("y")], [number(2)]), - call_stmt(call(var("print"), [var("x"), var("y")])) - ]) + ast = + chunk([ + local(["x"], [number(1)]), + assign([var("y")], [number(2)]), + call_stmt(call(var("print"), [var("x"), var("y")])) + ]) # Collect all variable names - vars = Walker.reduce(ast, [], fn - %Expr.Var{name: name}, acc -> [name | acc] - _, acc -> acc - end) + vars = + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) assert Enum.sort(vars) == ["print", "x", "y", "y"] # Collect all numbers - numbers = Walker.reduce(ast, [], fn - %Expr.Number{value: n}, acc -> [n | acc] - _, acc -> acc - end) + numbers = + Walker.reduce(ast, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) assert Enum.sort(numbers) == [1, 2] end test "builds maps from nodes" do # local x = 10; local y = 20 - ast = chunk([ - local(["x"], [number(10)]), - local(["y"], [number(20)]) - ]) + ast = + chunk([ + local(["x"], [number(10)]), + local(["y"], [number(20)]) + ]) # Build map of local declarations: name -> value - locals = Walker.reduce(ast, %{}, fn - %Stmt.Local{names: [name], values: [%Expr.Number{value: n}]}, acc -> - Map.put(acc, name, n) - _, acc -> - acc - end) + locals = + Walker.reduce(ast, %{}, fn + %Stmt.Local{names: [name], values: [%Expr.Number{value: n}]}, acc -> + Map.put(acc, name, n) + + _, acc -> + acc + end) assert locals == %{"x" => 10, "y" => 20} end test "accumulates deeply nested values" do # function f() return function() return 42 end end - ast = chunk([ - func_decl("f", [], [ - return_stmt([function_expr([], [ - return_stmt([number(42)]) - ])]) + ast = + chunk([ + func_decl("f", [], [ + return_stmt([ + function_expr([], [ + return_stmt([number(42)]) + ]) + ]) + ]) ]) - ]) # Count function expressions - func_count = Walker.reduce(ast, 0, fn - %Expr.Function{}, acc -> acc + 1 - _, acc -> acc - end) + func_count = + Walker.reduce(ast, 0, fn + %Expr.Function{}, acc -> acc + 1 + _, acc -> acc + end) assert func_count == 1 end @@ -236,20 +268,22 @@ defmodule Lua.AST.WalkerTest do # print(i) # end # end - ast = chunk([ - for_num("i", number(1), number(10), [ - if_stmt( - binop(:eq, binop(:mod, var("i"), number(2)), number(0)), - [call_stmt(call(var("print"), [var("i")]))] - ) + ast = + chunk([ + for_num("i", number(1), number(10), [ + if_stmt( + binop(:eq, binop(:mod, var("i"), number(2)), number(0)), + [call_stmt(call(var("print"), [var("i")]))] + ) + ]) ]) - ]) # Count all operators - ops = Walker.reduce(ast, [], fn - %Expr.BinOp{op: op}, acc -> [op | acc] - _, acc -> acc - end) + ops = + Walker.reduce(ast, [], fn + %Expr.BinOp{op: op}, acc -> [op | 
acc] + _, acc -> acc + end) assert :eq in ops assert :mod in ops @@ -257,26 +291,963 @@ defmodule Lua.AST.WalkerTest do test "handles table constructors" do # local t = {x = 1, y = 2, [3] = "three"} - ast = chunk([ - local(["t"], [ - table([ - {:record, string("x"), number(1)}, - {:record, string("y"), number(2)}, - {:record, number(3), string("three")} + ast = + chunk([ + local(["t"], [ + table([ + {:record, string("x"), number(1)}, + {:record, string("y"), number(2)}, + {:record, number(3), string("three")} + ]) ]) ]) - ]) # Count table fields - field_count = Walker.reduce(ast, 0, fn - %Expr.Table{fields: fields}, acc -> acc + length(fields) - _, acc -> acc - end) + field_count = + Walker.reduce(ast, 0, fn + %Expr.Table{fields: fields}, acc -> acc + length(fields) + _, acc -> acc + end) assert field_count == 3 end end + describe "expression nodes" do + test "walks MethodCall nodes" do + # obj:method(arg1, arg2) + ast = + chunk([ + call_stmt(method_call(var("obj"), "method", [var("arg1"), var("arg2")])) + ]) + + # Count all variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # obj, arg1, arg2 + assert var_count == 3 + end + + test "maps MethodCall nodes" do + # file:read("*a") + ast = + chunk([ + call_stmt(method_call(var("file"), "read", [string("*a")])) + ]) + + # Transform method name + transformed = + Walker.map(ast, fn + %Expr.MethodCall{method: m} = node -> %{node | method: "new_" <> m} + node -> node + end) + + # Extract method call + method_calls = + Walker.reduce(transformed, [], fn + %Expr.MethodCall{method: m}, acc -> [m | acc] + _, acc -> acc + end) + + assert method_calls == ["new_read"] + end + + test "walks Index nodes" do + # t[key] + ast = chunk([assign([index(var("t"), var("key"))], [number(42)])]) + + # Count variables + vars = + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(vars) == ["key", "t"] + end + + test "maps Index nodes" do + # arr[1] = arr[2] + ast = + chunk([ + assign( + [index(var("arr"), number(1))], + [index(var("arr"), number(2))] + ) + ]) + + # Double all indices + transformed = + Walker.map(ast, fn + %Expr.Index{key: %Expr.Number{value: n} = key} = node -> + %{node | key: %{key | value: n * 2}} + + node -> + node + end) + + # Collect indices + indices = + Walker.reduce(transformed, [], fn + %Expr.Index{key: %Expr.Number{value: n}}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(indices) == [2, 4] + end + + test "walks Property nodes" do + # io.write + ast = chunk([call_stmt(call(property(var("io"), "write"), [string("test")]))]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # io + assert var_count == 1 + end + + test "maps Property nodes" do + # math.pi + ast = chunk([assign([var("x")], [property(var("math"), "pi")])]) + + # Transform property field + transformed = + Walker.map(ast, fn + %Expr.Property{field: f} = node -> %{node | field: String.upcase(f)} + node -> node + end) + + # Extract property field + fields = + Walker.reduce(transformed, [], fn + %Expr.Property{field: f}, acc -> [f | acc] + _, acc -> acc + end) + + assert fields == ["PI"] + end + + test "walks String nodes" do + # local s = "hello" + ast = chunk([local(["s"], [string("hello")])]) + + # Collect strings + strings = + Walker.reduce(ast, [], fn + %Expr.String{value: s}, acc -> [s | acc] + _, acc -> acc + end) + + assert strings == ["hello"] + end + + test "maps String nodes" 
do + # print("hello", "world") + ast = chunk([call_stmt(call(var("print"), [string("hello"), string("world")]))]) + + # Uppercase all strings + transformed = + Walker.map(ast, fn + %Expr.String{value: s} = node -> %{node | value: String.upcase(s)} + node -> node + end) + + strings = + Walker.reduce(transformed, [], fn + %Expr.String{value: s}, acc -> [s | acc] + _, acc -> acc + end) + + assert Enum.sort(strings) == ["HELLO", "WORLD"] + end + + test "walks Nil nodes" do + # local x = nil + ast = chunk([local(["x"], [nil_lit()])]) + + # Count nil literals + nil_count = + Walker.reduce(ast, 0, fn + %Expr.Nil{}, acc -> acc + 1 + _, acc -> acc + end) + + assert nil_count == 1 + end + + test "walks Vararg nodes" do + # function(...) return ... end + ast = chunk([func_decl("f", [], [return_stmt([vararg()])], vararg: true)]) + + # Count vararg expressions + vararg_count = + Walker.reduce(ast, 0, fn + %Expr.Vararg{}, acc -> acc + 1 + _, acc -> acc + end) + + assert vararg_count == 1 + end + + test "maps Vararg nodes in function" do + # local f = function(...) return ... end + ast = chunk([local(["f"], [function_expr([], [return_stmt([vararg()])], vararg: true)])]) + + # Count nodes before and after map + count_before = count_nodes(ast) + + transformed = + Walker.map(ast, fn + node -> node + end) + + count_after = count_nodes(transformed) + + # Structure should be preserved + assert count_before == count_after + end + end + + describe "statement nodes" do + test "walks Local without values" do + # local x, y + ast = chunk([local(["x", "y"], [])]) + + # Count local statements + local_count = + Walker.reduce(ast, 0, fn + %Stmt.Local{}, acc -> acc + 1 + _, acc -> acc + end) + + assert local_count == 1 + + # Should have no child expressions + expr_count = + Walker.reduce(ast, 0, fn + %Expr.Number{}, acc -> acc + 1 + _, acc -> acc + end) + + assert expr_count == 0 + end + + test "maps Local without values" do + # local x + ast = chunk([local(["x"], [])]) + + # Transform should preserve empty values list + transformed = + Walker.map(ast, fn + %Stmt.Local{values: []} = node -> node + node -> node + end) + + # Extract local statement + locals = + Walker.reduce(transformed, [], fn + %Stmt.Local{names: names, values: values}, acc -> [{names, values} | acc] + _, acc -> acc + end) + + assert locals == [{["x"], []}] + end + + test "walks LocalFunc nodes" do + # local function add(a, b) return a + b end + ast = + chunk([local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # a, b (in return statement) + assert var_count == 2 + end + + test "maps LocalFunc nodes" do + # local function f() return 1 end + ast = chunk([local_func("f", [], [return_stmt([number(1)])])]) + + # Double numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [2] + end + + test "walks While nodes" do + # while x > 0 do x = x - 1 end + ast = + chunk([ + while_stmt( + binop(:gt, var("x"), number(0)), + [assign([var("x")], [binop(:sub, var("x"), number(1))])] + ) + ]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{name: "x"}, acc -> acc + 1 + _, acc -> acc + end) + + # x appears 3 times: condition, target, value + assert var_count == 3 + end + + test "maps While 
nodes" do + # while true do print(1) end + ast = chunk([while_stmt(bool(true), [call_stmt(call(var("print"), [number(1)]))])]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 10} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [11] + end + + test "walks Repeat nodes" do + # repeat x = x - 1 until x <= 0 + ast = + chunk([ + repeat_stmt( + [assign([var("x")], [binop(:sub, var("x"), number(1))])], + binop(:le, var("x"), number(0)) + ) + ]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{name: "x"}, acc -> acc + 1 + _, acc -> acc + end) + + # x appears 3 times: target, value, condition + assert var_count == 3 + end + + test "maps Repeat nodes" do + # repeat print(5) until true + ast = chunk([repeat_stmt([call_stmt(call(var("print"), [number(5)]))], bool(true))]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [10] + end + + test "walks ForNum with nil step" do + # for i = 1, 10 do print(i) end + ast = + chunk([for_num("i", number(1), number(10), [call_stmt(call(var("print"), [var("i")]))])]) + + # Verify step is nil + step_is_nil = + Walker.reduce(ast, false, fn + %Stmt.ForNum{step: nil}, _acc -> true + _, acc -> acc + end) + + assert step_is_nil + end + + test "walks ForNum with explicit step" do + # for i = 1, 10, 2 do print(i) end + ast = + chunk([ + for_num("i", number(1), number(10), [call_stmt(call(var("print"), [var("i")]))], + step: number(2) + ) + ]) + + # Count numbers (start, limit, step) + numbers = + Walker.reduce(ast, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + # 1, 10, 2 from the loop header (i is a var in the body) + assert Enum.sort(numbers) == [1, 2, 10] + end + + test "maps ForNum with nil step" do + # for i = 1, 10 do end + ast = chunk([for_num("i", number(1), number(10), [])]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 5} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [6, 15] + end + + test "maps ForNum with explicit step" do + # for i = 2, 20, 3 do end + ast = chunk([for_num("i", number(2), number(20), [], step: number(3))]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 10} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [20, 30, 200] + end + + test "walks ForIn nodes" do + # for k, v in pairs(t) do print(k, v) end + ast = + chunk([ + for_in( + ["k", "v"], + [call(var("pairs"), [var("t")])], + [call_stmt(call(var("print"), [var("k"), var("v")]))] + ) + ]) + + # Count variables + vars = + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + # pairs, t, print, k, v + assert Enum.sort(vars) == ["k", "pairs", "print", "t", "v"] + end + + test "maps ForIn nodes" do + # for x in iter() do print(x) end + ast = + chunk([ + for_in(["x"], [call(var("iter"), [])], [call_stmt(call(var("print"), 
[var("x")]))]) + ]) + + # Transform variable names + transformed = + Walker.map(ast, fn + %Expr.Var{name: name} = node -> %{node | name: "new_" <> name} + node -> node + end) + + vars = + Walker.reduce(transformed, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(vars) == ["new_iter", "new_print", "new_x"] + end + + test "walks Do nodes" do + # do local x = 1; print(x) end + ast = + chunk([ + do_block([ + local(["x"], [number(1)]), + call_stmt(call(var("print"), [var("x")])) + ]) + ]) + + # Count do statements + do_count = + Walker.reduce(ast, 0, fn + %Stmt.Do{}, acc -> acc + 1 + _, acc -> acc + end) + + assert do_count == 1 + end + + test "maps Do nodes" do + # do print(5) end + ast = chunk([do_block([call_stmt(call(var("print"), [number(5)]))])]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 3} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [15] + end + + test "walks Break nodes" do + # while true do break end + ast = chunk([while_stmt(bool(true), [break_stmt()])]) + + # Count break statements + break_count = + Walker.reduce(ast, 0, fn + %Stmt.Break{}, acc -> acc + 1 + _, acc -> acc + end) + + assert break_count == 1 + end + + test "maps Break nodes (leaf node)" do + # while true do break end + ast = chunk([while_stmt(bool(true), [break_stmt()])]) + + # Map should preserve break + transformed = + Walker.map(ast, fn + node -> node + end) + + break_count = + Walker.reduce(transformed, 0, fn + %Stmt.Break{}, acc -> acc + 1 + _, acc -> acc + end) + + assert break_count == 1 + end + + test "walks Goto nodes" do + # goto skip + ast = chunk([goto_stmt("skip")]) + + # Count goto statements + goto_labels = + Walker.reduce(ast, [], fn + %Stmt.Goto{label: label}, acc -> [label | acc] + _, acc -> acc + end) + + assert goto_labels == ["skip"] + end + + test "maps Goto nodes (leaf node)" do + # goto target + ast = chunk([goto_stmt("target")]) + + # Map should preserve goto + transformed = + Walker.map(ast, fn + node -> node + end) + + labels = + Walker.reduce(transformed, [], fn + %Stmt.Goto{label: label}, acc -> [label | acc] + _, acc -> acc + end) + + assert labels == ["target"] + end + + test "walks Label nodes" do + # ::start:: + ast = chunk([label("start")]) + + # Count labels + labels = + Walker.reduce(ast, [], fn + %Stmt.Label{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert labels == ["start"] + end + + test "maps Label nodes (leaf node)" do + # ::loop:: + ast = chunk([label("loop")]) + + # Map should preserve label + transformed = + Walker.map(ast, fn + node -> node + end) + + labels = + Walker.reduce(transformed, [], fn + %Stmt.Label{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert labels == ["loop"] + end + + test "walks FuncDecl nodes" do + # function add(a, b) return a + b end + ast = + chunk([func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # a, b (in return statement) + assert var_count == 2 + end + + test "maps FuncDecl nodes" do + # function f() return 1 end + ast = chunk([func_decl("f", [], [return_stmt([number(1)])])]) + + # Double numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + numbers = + 
Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [2] + end + end + + describe "edge cases" do + test "If with no elseifs or else" do + # if x then print(x) end + ast = chunk([if_stmt(var("x"), [call_stmt(call(var("print"), [var("x")]))])]) + + # Verify structure + if_stmts = + Walker.reduce(ast, [], fn + %Stmt.If{elseifs: elseifs, else_block: else_block}, acc -> + [{elseifs, else_block} | acc] + + _, acc -> + acc + end) + + assert if_stmts == [{[], nil}] + end + + test "If with elseif clauses mapping" do + # if x > 0 then return 1 elseif x < 0 then return -1 end + ast = + chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([number(-1)])]}] + ) + ]) + + # Map should traverse elseif conditions and blocks + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 10} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + # Should have transformed: 0, 1, 0, -1 -> 0, 10, 0, -10 + assert Enum.sort(numbers) == [-10, 0, 0, 10] + end + + test "UnOp expressions mapping" do + # local x = -5 + ast = chunk([local(["x"], [unop(:neg, number(5))])]) + + # Map should traverse unary operations + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [10] + end + + test "Local without values mapping" do + # local x + ast = chunk([local(["x"], nil)]) + + # Map should handle Local without values + transformed = + Walker.map(ast, fn + node -> node + end) + + locals = + Walker.reduce(transformed, [], fn + %Stmt.Local{names: names, values: values}, acc -> [{names, values} | acc] + _, acc -> acc + end) + + assert locals == [{["x"], nil}] + end + + test "Table with list fields" do + # {1, 2, 3} + ast = + chunk([ + local(["t"], [ + table([ + {:list, number(1)}, + {:list, number(2)}, + {:list, number(3)} + ]) + ]) + ]) + + # Count numbers + numbers = + Walker.reduce(ast, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [1, 2, 3] + end + + test "Table with mixed list and record fields" do + # {10, 20, x = 30} + ast = + chunk([ + local(["t"], [ + table([ + {:list, number(10)}, + {:list, number(20)}, + {:record, string("x"), number(30)} + ]) + ]) + ]) + + # Map should handle both field types + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 1} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [11, 21, 31] + end + + test "Empty table" do + # local t = {} + ast = chunk([local(["t"], [table([])])]) + + # Count table nodes + table_count = + Walker.reduce(ast, 0, fn + %Expr.Table{}, acc -> acc + 1 + _, acc -> acc + end) + + assert table_count == 1 + + # Should have no field children + field_count = + Walker.reduce(ast, 0, fn + %Expr.Table{fields: fields}, acc -> acc + length(fields) + _, acc -> acc + end) + + assert field_count == 0 + end + + test "Nested method calls" do + # obj:method1():method2() + ast = + chunk([ + call_stmt( + method_call( + method_call(var("obj"), "method1", []), + "method2", + [] + ) + ) + ]) + + # Count method 
calls + method_count = + Walker.reduce(ast, 0, fn + %Expr.MethodCall{}, acc -> acc + 1 + _, acc -> acc + end) + + assert method_count == 2 + end + + test "Complex nested indexing" do + # t[a][b][c] + ast = + chunk([ + assign( + [index(index(index(var("t"), var("a")), var("b")), var("c"))], + [number(1)] + ) + ]) + + # Count index operations + index_count = + Walker.reduce(ast, 0, fn + %Expr.Index{}, acc -> acc + 1 + _, acc -> acc + end) + + assert index_count == 3 + end + + test "Multiple return values" do + # return a, b, c + ast = chunk([return_stmt([var("a"), var("b"), var("c")])]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + assert var_count == 3 + end + + test "CallStmt with MethodCall" do + # obj:method() + ast = chunk([call_stmt(method_call(var("obj"), "method", []))]) + + # Should walk through CallStmt to MethodCall + call_stmt_count = + Walker.reduce(ast, 0, fn + %Stmt.CallStmt{}, acc -> acc + 1 + _, acc -> acc + end) + + method_call_count = + Walker.reduce(ast, 0, fn + %Expr.MethodCall{}, acc -> acc + 1 + _, acc -> acc + end) + + assert call_stmt_count == 1 + assert method_call_count == 1 + end + + test "Deeply nested expressions" do + # ((a + b) * (c + d)) / ((e - f) * (g - h)) + ast = + chunk([ + local(["x"], [ + binop( + :div, + binop(:mul, binop(:add, var("a"), var("b")), binop(:add, var("c"), var("d"))), + binop(:mul, binop(:sub, var("e"), var("f")), binop(:sub, var("g"), var("h"))) + ) + ]) + ]) + + # Count binary operations + binop_count = + Walker.reduce(ast, 0, fn + %Expr.BinOp{}, acc -> acc + 1 + _, acc -> acc + end) + + # 7 binary operations total + assert binop_count == 7 + + # Count variables + vars = + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(vars) == ["a", "b", "c", "d", "e", "f", "g", "h"] + end + end + # Helper to count nodes defp count_nodes(ast) do Walker.reduce(ast, 0, fn _, acc -> acc + 1 end) diff --git a/test/lua/lexer_test.exs b/test/lua/lexer_test.exs index e59c107..93b8739 100644 --- a/test/lua/lexer_test.exs +++ b/test/lua/lexer_test.exs @@ -13,20 +13,20 @@ defmodule Lua.LexerTest do :else, :elseif, :end, - :false, + false, :for, :function, :goto, :if, :in, :local, - :nil, + nil, :not, :or, :repeat, :return, :then, - :true, + true, :until, :while ] @@ -94,6 +94,36 @@ defmodule Lua.LexerTest do assert num == 3.0e2 end + test "tokenizes floats with scientific notation" do + # Float with exponent + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("2.5e3") + assert num == 2.5e3 + + # Float with uppercase E + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("1.0E10") + assert num == 1.0e10 + + # Integer with exponent (becomes float) + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("5e2") + assert num == 5.0e2 + + # Exponent without sign + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("1e5") + assert num == 1.0e5 + end + + test "handles edge cases in scientific notation" do + # Exponent without digits - should result in error + assert {:error, {:invalid_number, _}} = Lexer.tokenize("1e") + + # Exponent with sign but no digits - should result in error + assert {:error, {:invalid_number, _}} = Lexer.tokenize("1e+") + assert {:error, {:invalid_number, _}} = Lexer.tokenize("1e-") + + # Float with exponent without digits + assert {:error, {:invalid_number, _}} = Lexer.tokenize("1.5e") + end + test "handles trailing dot correctly" do # "42." 
should be tokenized as number 42 followed by dot operator # But in Lua, "42." is actually a valid number (42.0) @@ -119,7 +149,9 @@ defmodule Lua.LexerTest do end test "handles escape sequences in strings" do - assert {:ok, [{:string, "hello\nworld", _}, {:eof, _}]} = Lexer.tokenize(~s("hello\\nworld")) + assert {:ok, [{:string, "hello\nworld", _}, {:eof, _}]} = + Lexer.tokenize(~s("hello\\nworld")) + assert {:ok, [{:string, "tab\there", _}, {:eof, _}]} = Lexer.tokenize(~s("tab\\there")) assert {:ok, [{:string, "quote\"here", _}, {:eof, _}]} = @@ -129,6 +161,25 @@ Lexer.tokenize(~s("backslash\\\\here")) end + test "handles all standard escape sequences" do + # Test \a (bell, ASCII 7) + assert {:ok, [{:string, <<7>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\a")) + # Test \b (backspace, ASCII 8) + assert {:ok, [{:string, <<8>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\b")) + # Test \f (form feed, ASCII 12) + assert {:ok, [{:string, <<12>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\f")) + # Test \r (carriage return, ASCII 13) + assert {:ok, [{:string, <<13>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\r")) + # Test \v (vertical tab, ASCII 11) + assert {:ok, [{:string, <<11>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\v")) + # Test \' (single quote) + assert {:ok, [{:string, "'", _}, {:eof, _}]} = Lexer.tokenize(~s("\\'")) + # Test \" (double quote) + assert {:ok, [{:string, "\"", _}, {:eof, _}]} = Lexer.tokenize(~s("\\"")) + # Test \\ (backslash) + assert {:ok, [{:string, "\\", _}, {:eof, _}]} = Lexer.tokenize(~s("\\\\")) + end + test "tokenizes long strings with [[...]]" do assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize("[[hello]]") assert {:ok, [{:string, "", _}, {:eof, _}]} = Lexer.tokenize("[[]]") @@ -143,6 +194,16 @@ assert {:ok, [{:string, "a]b", _}, {:eof, _}]} = Lexer.tokenize("[=[a]b]=]") end + test "long strings with false closing brackets" do + # ] not followed by the right number of = + assert {:ok, [{:string, " test ] more ", _}, {:eof, _}]} = + Lexer.tokenize("[=[ test ] more ]=]") + + # ] not followed by ] + assert {:ok, [{:string, " test ]= more ", _}, {:eof, _}]} = + Lexer.tokenize("[[ test ]= more ]]") + end + test "reports error for unclosed string" do assert {:error, {:unclosed_string, _}} = Lexer.tokenize(~s("hello)) assert {:error, {:unclosed_string, _}} = Lexer.tokenize("'hello") @@ -245,6 +306,24 @@ test "reports error for unclosed multi-line comment" do assert {:error, {:unclosed_comment, _}} = Lexer.tokenize("--[[ unclosed comment") end + + test "handles false closing brackets in multi-line comments" do + # Test a ] that is not followed by the right number of = + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[=[ test ] more ]=]") + # Test a ] that is not followed by ] + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ test ]= more ]]") + end + + test "multi-line comment with newlines" do + code = "--[[ line 1\nline 2\nline 3 ]]" + assert {:ok, [{:eof, _}]} = Lexer.tokenize(code) + end + + test "multi-line comment level 0" do + # Test the actual --[[ path + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ comment ]]") + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("--[[ comment ]]x") + end end describe "whitespace" do @@ -258,6 +337,26 @@ assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\r\n2") assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\r2") end + + test "handles different newline styles in code" do + # CRLF newline + assert {:ok, 
tokens} = Lexer.tokenize("x\r\ny") + + assert [ + {:identifier, "x", _}, + {:identifier, "y", _}, + {:eof, _} + ] = tokens + + # CR only newline + assert {:ok, tokens} = Lexer.tokenize("x\ry") + + assert [ + {:identifier, "x", _}, + {:identifier, "y", _}, + {:eof, _} + ] = tokens + end end describe "position tracking" do @@ -409,6 +508,121 @@ defmodule Lua.LexerTest do assert {:ok, [{:operator, :gt, _}, {:operator, :assign, _}, {:eof, _}]} = Lexer.tokenize("> =") end + + test "handles invalid escape sequences in strings" do + # Invalid escape sequences should be included as-is + assert {:ok, [{:string, "\\x", _}, {:eof, _}]} = Lexer.tokenize(~s("\\x")) + assert {:ok, [{:string, "\\z", _}, {:eof, _}]} = Lexer.tokenize(~s("\\z")) + assert {:ok, [{:string, "\\1", _}, {:eof, _}]} = Lexer.tokenize(~s("\\1")) + end + + test "reports error for string with unescaped newline" do + assert {:error, {:unclosed_string, _}} = Lexer.tokenize("\"hello\n") + assert {:error, {:unclosed_string, _}} = Lexer.tokenize("'hello\n") + end + + test "handles trailing dot after number" do + # "42." should tokenize as number 42 followed by dot + assert {:ok, tokens} = Lexer.tokenize("42.") + assert [{:number, 42, _}, {:delimiter, :dot, _}, {:eof, _}] = tokens + end + + test "handles decimal point without following digit" do + # "42.x" should be number 42 followed by dot and identifier x + assert {:ok, tokens} = Lexer.tokenize("42.x") + assert [{:number, 42, _}, {:delimiter, :dot, _}, {:identifier, "x", _}, {:eof, _}] = tokens + end + + test "reports error for invalid hex number" do + assert {:error, {:invalid_hex_number, _}} = Lexer.tokenize("0x") + assert {:error, {:invalid_hex_number, _}} = Lexer.tokenize("0xg") + end + + test "handles uppercase X in hex numbers" do + assert {:ok, [{:number, 255, _}, {:eof, _}]} = Lexer.tokenize("0XFF") + assert {:ok, [{:number, 10, _}, {:eof, _}]} = Lexer.tokenize("0Xa") + end + + test "single-line comment ending with LF" do + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("-- comment\nx") + end + + test "single-line comment ending with CR" do + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("-- comment\rx") + end + + test "single-line comment ending with CRLF" do + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("-- comment\r\nx") + end + + test "single-line comment at end of file" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("-- comment at EOF") + end + + test "comment starting with --[ but not --[[" do + # This should be treated as a single-line comment, not a multi-line comment + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[ this is single line") + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[= not multi-line") + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[x not multi-line") + end + + test "multi-line comment with mismatched bracket level" do + # The closing bracket doesn't match the opening level + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[=[ comment ]]") + # This should continue scanning until EOF because ]] doesn't match the opening [=[ + end + + test "long string with mismatched closing bracket" do + # Opening [=[ but closing with ]] + assert {:error, {:unclosed_long_string, _}} = Lexer.tokenize("[=[ string ]]") + end + + test "long bracket not actually a long bracket" do + # "[" followed by something other than "=" or "[" should be treated as delimiter + assert {:ok, [{:delimiter, :lbracket, _}, {:identifier, "x", _}, {:eof, _}]} = + Lexer.tokenize("[x") + end + + test "right bracket delimiter" do + 
assert {:ok, [{:delimiter, :rbracket, _}, {:eof, _}]} = Lexer.tokenize("]") + end + end + + describe "additional edge cases for coverage" do + test "whitespace at start with CRLF" do + # Test CRLF at very start of file + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("\r\nx") + end + + test "whitespace at start with CR" do + # Test CR at very start of file + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("\rx") + end + + test "single-line comment with --[ at start" do + # Ensure --[ path is taken + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[not a multiline") + end + + test "multiline comment with --[[ at start" do + # Ensure --[[ path is taken + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ multiline ]]") + end + + test "long string starting at beginning" do + assert {:ok, [{:string, "test", _}, {:eof, _}]} = Lexer.tokenize("[[test]]") + end + + test "number followed by concat operator" do + # Test the path where we have ".." after a number + assert {:ok, [{:number, 5, _}, {:operator, :concat, _}, {:eof, _}]} = + Lexer.tokenize("5..") + end + + test "uppercase E in scientific notation for integer" do + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("2E5") + assert num == 2.0e5 + end end describe "real Lua code examples" do diff --git a/test/lua/parser/beautiful_errors_test.exs b/test/lua/parser/beautiful_errors_test.exs deleted file mode 100644 index 02bbabb..0000000 --- a/test/lua/parser/beautiful_errors_test.exs +++ /dev/null @@ -1,269 +0,0 @@ -defmodule Lua.Parser.BeautifulErrorsTest do - use ExUnit.Case, async: true - alias Lua.Parser - - @moduletag :beautiful_errors - - describe "beautiful error message demonstrations" do - test "missing 'end' keyword shows context and suggestion" do - code = """ - function factorial(n) - if n <= 1 then - return 1 - else - return n * factorial(n - 1) - -- Missing 'end' here! - """ - - assert {:error, msg} = Parser.parse(code) - - # Check for essential components - assert msg =~ ~r/Parse Error/i - assert msg =~ "line" - assert msg =~ "Expected" - assert msg =~ "end" - - # Check for visual formatting - assert msg =~ "│" # Line separator - assert msg =~ "^" # Error pointer - - # Should have ANSI color codes - assert msg =~ "\e[" - - # Print for manual inspection during test runs - if System.get_env("SHOW_ERRORS") do - IO.puts("\n" <> String.duplicate("=", 70)) - IO.puts("Example 1: Missing 'end' keyword") - IO.puts(String.duplicate("=", 70)) - IO.puts(msg) - IO.puts(String.duplicate("=", 70) <> "\n") - end - end - - test "missing 'then' keyword provides helpful suggestion" do - code = """ - if x > 0 - print(x) - end - """ - - assert {:error, msg} = Parser.parse(code) - - assert msg =~ "Parse Error" - assert msg =~ "Expected" - assert msg =~ ":then" - assert msg =~ "line 2" - - # Should show the problematic line - assert msg =~ "print(x)" - - if System.get_env("SHOW_ERRORS") do - IO.puts("\n" <> String.duplicate("=", 70)) - IO.puts("Example 2: Missing 'then' keyword") - IO.puts(String.duplicate("=", 70)) - IO.puts(msg) - IO.puts(String.duplicate("=", 70) <> "\n") - end - end - - test "unclosed string shows line with error pointer" do - code = """ - local message = "Hello, World! 
- print(message) - """ - - assert {:error, msg} = Parser.parse(code) - - assert msg =~ "Parse Error" - assert msg =~ "Unclosed string" - assert msg =~ "line 1" - - # Should show suggestion - assert msg =~ "Suggestion" - assert msg =~ "closing quote" - - # Should show the unclosed string line - assert msg =~ ~s(local message = "Hello, World!) - - if System.get_env("SHOW_ERRORS") do - IO.puts("\n" <> String.duplicate("=", 70)) - IO.puts("Example 3: Unclosed string") - IO.puts(String.duplicate("=", 70)) - IO.puts(msg) - IO.puts(String.duplicate("=", 70) <> "\n") - end - end - - test "missing closing parenthesis shows context" do - code = """ - local function test(a, b - return a + b - end - """ - - assert {:error, msg} = Parser.parse(code) - - assert msg =~ "Parse Error" - assert msg =~ "Expected" - assert msg =~ ":rparen" - - if System.get_env("SHOW_ERRORS") do - IO.puts("\n" <> String.duplicate("=", 70)) - IO.puts("Example 4: Missing closing parenthesis") - IO.puts(String.duplicate("=", 70)) - IO.puts(msg) - IO.puts(String.duplicate("=", 70) <> "\n") - end - end - - test "invalid character shows clear message" do - code = """ - local x = 42 - local y = @invalid - """ - - assert {:error, msg} = Parser.parse(code) - - assert msg =~ "Parse Error" - assert msg =~ "Unexpected character" - assert msg =~ "line 2" - assert msg =~ "@" - - # Should have suggestion - assert msg =~ "Suggestion" - - if System.get_env("SHOW_ERRORS") do - IO.puts("\n" <> String.duplicate("=", 70)) - IO.puts("Example 5: Invalid character") - IO.puts(String.duplicate("=", 70)) - IO.puts(msg) - IO.puts(String.duplicate("=", 70) <> "\n") - end - end - - test "missing 'do' in while loop" do - code = """ - while x > 0 - x = x - 1 - end - """ - - assert {:error, msg} = Parser.parse(code) - - assert msg =~ "Parse Error" - assert msg =~ "Expected" - assert msg =~ ":do" - - if System.get_env("SHOW_ERRORS") do - IO.puts("\n" <> String.duplicate("=", 70)) - IO.puts("Example 6: Missing 'do' in while loop") - IO.puts(String.duplicate("=", 70)) - IO.puts(msg) - IO.puts(String.duplicate("=", 70) <> "\n") - end - end - - test "complex error with multiple context lines" do - code = """ - function complex_function() - local x = 10 - local y = 20 - if x > y then - return x - -- Missing 'end' for if - return y - -- Missing 'end' for function - """ - - assert {:error, msg} = Parser.parse(code) - - assert msg =~ "Parse Error" - - # Error is at EOF (line 9), so context shows lines around line 9 - # Should show the lines that are actually in the context window (lines 7-9) - assert msg =~ "return y" - assert msg =~ "-- Missing 'end' for function" - - if System.get_env("SHOW_ERRORS") do - IO.puts("\n" <> String.duplicate("=", 70)) - IO.puts("Example 7: Complex error with context") - IO.puts(String.duplicate("=", 70)) - IO.puts(msg) - IO.puts(String.duplicate("=", 70) <> "\n") - end - end - - test "error message formatting has proper structure" do - code = "if x then" - - assert {:error, msg} = Parser.parse(code) - - # Check structure components - assert msg =~ "Parse Error" - assert msg =~ "at line" - assert msg =~ "column" - - # Check visual elements - assert msg =~ "│" # Box drawing character for line separator - assert msg =~ "^" # Pointer to error location - - # Check color codes (ANSI escape sequences) - assert String.contains?(msg, "\e[31m") # Red color for error - assert String.contains?(msg, "\e[0m") # Reset color - end - end - - describe "error message quality checks" do - test "always includes line and column information when available" 
do - code = """ - local x = 1 - if x > 0 then - print(x - end - """ - - assert {:error, msg} = Parser.parse(code) - assert msg =~ ~r/line \d+/ - assert msg =~ ~r/column \d+/ - end - - test "always includes visual pointer to error location" do - code = "local x = +" - - assert {:error, msg} = Parser.parse(code) - assert msg =~ "^" # Caret pointer - end - - test "shows surrounding context lines" do - code = """ - line1 = 1 - line2 = 2 - if x then - line4 = 4 - line5 = 5 - """ - - # This will fail to parse due to missing 'end' - {:error, msg} = Parser.parse(code) - - # Should show lines around the error with box drawing - assert msg =~ "│" # Line separator - assert msg =~ "line" # Should show context lines - end - - test "uses colors for better readability" do - code = "if x then" - - assert {:error, msg} = Parser.parse(code) - - # Red for errors - assert msg =~ "\e[31m" - # Bright/bold - assert msg =~ "\e[1m" - # Reset - assert msg =~ "\e[0m" - # Cyan for suggestions - assert msg =~ "\e[36m" - end - end -end diff --git a/test/lua/parser/error_test.exs b/test/lua/parser/error_test.exs index 15f268d..5bd15c5 100644 --- a/test/lua/parser/error_test.exs +++ b/test/lua/parser/error_test.exs @@ -1,31 +1,34 @@ defmodule Lua.Parser.ErrorTest do + @moduledoc """ + Tests for parser error messages, including formatting and suggestions. + """ use ExUnit.Case, async: true alias Lua.Parser - describe "beautiful error messages" do - test "missing 'end' keyword shows helpful message" do + describe "syntax errors" do + test "missing 'end' keyword" do code = """ function foo() return 1 """ - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Parse Error") - assert String.contains?(error_msg, "line 3") - assert String.contains?(error_msg, "Expected") - assert String.contains?(error_msg, "'end'") + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Parse Error" + assert msg =~ "line 3" + assert msg =~ "Expected" + assert msg =~ "'end'" end - test "missing 'then' keyword provides suggestion" do + test "missing 'then' keyword" do code = """ if x > 0 return x end """ - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Expected") - assert String.contains?(error_msg, ":then") + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Expected" + assert msg =~ ":then" end test "missing 'do' keyword in while loop" do @@ -35,52 +38,95 @@ defmodule Lua.Parser.ErrorTest do end """ - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Expected") - assert String.contains?(error_msg, ":do") + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Expected" + assert msg =~ ":do" end - test "unclosed string shows context" do + test "missing closing parenthesis" do + code = "print(1, 2, 3" + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Parse Error" + end + + test "missing closing bracket" do + code = "local t = {1, 2, 3" + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Parse Error" + end + + test "unexpected token in expression" do + code = "local x = 1 + + 2" + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Parse Error" + end + + test "complex nested function with missing end" do code = """ - local x = "hello + function factorial(n) + if n <= 1 then + return 1 + else + return n * factorial(n - 1) + -- Missing 'end' here! 
""" - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Unclosed string") - assert String.contains?(error_msg, "line 1") + assert {:error, msg} = Parser.parse(code) + assert msg =~ ~r/Parse Error/i + assert msg =~ "line" + assert msg =~ "Expected" + assert msg =~ "end" end + end - test "unexpected character shows position" do + describe "lexer errors" do + test "unclosed string" do + code = ~s(local x = "hello) + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Unclosed string" + assert msg =~ "line 1" + end + + test "unexpected character" do code = """ local x = 42 local y = @invalid """ - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Unexpected character") - assert String.contains?(error_msg, "line 2") + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Unexpected character" + assert msg =~ "line 2" + assert msg =~ "@" end + end - test "missing closing parenthesis" do - code = """ - print(1, 2, 3 - """ - - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Parse Error") - # Should mention parenthesis or bracket + describe "error message formatting" do + test "includes visual formatting elements" do + code = "if x > 0" + + assert {:error, msg} = Parser.parse(code) + # Line separator + assert msg =~ "│" + # Error pointer + assert msg =~ "^" + # ANSI color codes + assert msg =~ "\e[" end - test "missing closing bracket" do + test "includes line and column information" do code = """ - local t = {1, 2, 3 + local x = 1 + if x > 0 then + print(x + end """ - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Parse Error") + assert {:error, msg} = Parser.parse(code) + assert msg =~ "line" + assert msg =~ "column" end - test "shows context with line numbers" do + test "shows context lines around error" do code = """ local x = 1 local y = 2 @@ -89,82 +135,40 @@ defmodule Lua.Parser.ErrorTest do end """ - assert {:error, error_msg} = Parser.parse(code) - # Should show context around the error - assert String.contains?(error_msg, "│") - end - - test "unexpected token in expression" do - code = """ - local x = 1 + + 2 - """ - - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Parse Error") - end - - test "invalid syntax after valid code" do - code = """ - function add(a, b) - return a + b - end - - function multiply(x, y - return x * y - end - """ - - assert {:error, error_msg} = Parser.parse(code) - # Error is on line 6 (the return statement) because line 5 is missing closing ) - assert String.contains?(error_msg, "line 6") + assert {:error, msg} = Parser.parse(code) + assert msg =~ "│" end - end - describe "error message formatting" do - test "formats with color codes for terminal" do - code = "if x then" - - assert {:error, error_msg} = Parser.parse(code) - # Color codes should be present (ANSI escape codes) - assert String.contains?(error_msg, "\e[") - end - - test "shows helpful suggestions" do + test "provides helpful suggestions" do code = """ function test() print("hello") """ - assert {:error, error_msg} = Parser.parse(code) - assert String.contains?(error_msg, "Suggestion") + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Suggestion" end - test "includes line and column information" do - code = """ - local x = 1 - if x > 0 then - print(x - end - """ + test "uses ANSI colors for better readability" do + code = "if x then" - assert {:error, error_msg} = Parser.parse(code) - assert 
String.contains?(error_msg, "line") - assert String.contains?(error_msg, "column") + assert {:error, msg} = Parser.parse(code) + assert msg =~ "\e[31m" # Red for errors + assert msg =~ "\e[1m" # Bold + assert msg =~ "\e[0m" # Reset + assert msg =~ "\e[36m" # Cyan for suggestions end end - describe "raw error parsing" do - test "parse_raw returns structured error" do + describe "parse_raw API" do + test "returns structured error tuple" do code = "if x then" - assert {:error, error_tuple} = Parser.parse_raw(code) - # Should be a tuple, not a formatted string assert is_tuple(error_tuple) end - test "parse_raw successful parsing" do + test "returns AST on success" do code = "local x = 42" - assert {:ok, chunk} = Parser.parse_raw(code) assert chunk.__struct__ == Lua.AST.Chunk end diff --git a/test/lua/parser/error_unit_test.exs b/test/lua/parser/error_unit_test.exs new file mode 100644 index 0000000..c9d566d --- /dev/null +++ b/test/lua/parser/error_unit_test.exs @@ -0,0 +1,627 @@ +defmodule Lua.Parser.ErrorUnitTest do + use ExUnit.Case, async: true + + alias Lua.Parser.Error + + describe "new/4" do + test "creates error with all fields" do + position = %{line: 1, column: 5, byte_offset: 0} + related = [Error.new(:unexpected_token, "related error", nil)] + + error = + Error.new( + :unexpected_token, + "test message", + position, + suggestion: "test suggestion", + source_lines: ["line 1", "line 2"], + related: related + ) + + assert error.type == :unexpected_token + assert error.message == "test message" + assert error.position == position + assert error.suggestion == "test suggestion" + assert error.source_lines == ["line 1", "line 2"] + assert error.related == related + end + + test "creates error with minimal fields" do + error = Error.new(:expected_token, "minimal error") + + assert error.type == :expected_token + assert error.message == "minimal error" + assert error.position == nil + assert error.suggestion == nil + assert error.source_lines == [] + assert error.related == [] + end + + test "creates error with position but no opts" do + position = %{line: 2, column: 10, byte_offset: 15} + error = Error.new(:invalid_syntax, "with position", position) + + assert error.type == :invalid_syntax + assert error.message == "with position" + assert error.position == position + assert error.suggestion == nil + assert error.source_lines == [] + assert error.related == [] + end + end + + describe "unexpected_token/4" do + test "creates error for delimiter token" do + position = %{line: 1, column: 5, byte_offset: 0} + error = Error.unexpected_token(:delimiter, ")", position, "expression") + + assert error.type == :unexpected_token + assert String.contains?(error.message, "Unexpected ')'") + assert String.contains?(error.message, "in expression") + assert error.position == position + assert error.suggestion == "Check for missing operators or keywords before this delimiter" + end + + test "creates error for token in expression context" do + position = %{line: 2, column: 3, byte_offset: 10} + error = Error.unexpected_token(:keyword, "end", position, "primary expression") + + assert error.type == :unexpected_token + assert String.contains?(error.message, "in primary expression") + assert error.position == position + + assert error.suggestion == + "Expected an expression here (variable, number, string, table, function, etc.)" + end + + test "creates error for token in statement context" do + position = %{line: 3, column: 1, byte_offset: 20} + error = Error.unexpected_token(:number, 42, position, 
"statement") + + assert error.type == :unexpected_token + assert String.contains?(error.message, "in statement") + assert error.position == position + + assert error.suggestion == + "Expected a statement here (assignment, function call, if, while, for, etc.)" + end + + test "creates error with no specific suggestion" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.unexpected_token(:identifier, "foo", position, "block") + + assert error.type == :unexpected_token + assert error.position == position + assert error.suggestion == nil + end + end + + describe "expected_token/5" do + test "creates error for expected 'end' keyword" do + position = %{line: 10, column: 1, byte_offset: 100} + error = Error.expected_token(:keyword, :end, :eof, nil, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'end'") + assert String.contains?(error.message, "but got end of input") + assert error.position == position + + assert String.contains?( + error.suggestion, + "Add 'end' to close the block. Check that all opening keywords" + ) + end + + test "creates error for expected 'then' keyword" do + position = %{line: 5, column: 20, byte_offset: 50} + error = Error.expected_token(:keyword, :then, :identifier, "x", position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'then'") + assert String.contains?(error.message, "but got identifier 'x'") + + assert String.contains?( + error.suggestion, + "Add 'then' after the condition. Lua requires 'then' after if/elseif conditions." + ) + end + + test "creates error for expected 'do' keyword" do + position = %{line: 7, column: 15, byte_offset: 75} + error = Error.expected_token(:keyword, :do, :keyword, :end, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'do'") + assert String.contains?(error.message, "but got 'end'") + + assert String.contains?( + error.suggestion, + "Add 'do' to start the loop body. Lua requires 'do' after while/for conditions." 
+ ) + end + + test "creates error for expected ')' delimiter" do + position = %{line: 3, column: 12, byte_offset: 30} + error = Error.expected_token(:delimiter, :rparen, :delimiter, :comma, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'rparen'") + assert String.contains?(error.message, "but got 'comma'") + assert String.contains?(error.suggestion, "Add ')' to close the parentheses") + end + + test "creates error for expected ']' delimiter" do + position = %{line: 4, column: 8, byte_offset: 40} + error = Error.expected_token(:delimiter, :rbracket, :eof, nil, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'rbracket'") + assert String.contains?(error.message, "but got end of input") + assert String.contains?(error.suggestion, "Add ']' to close the brackets") + end + + test "creates error for expected '}' delimiter" do + position = %{line: 6, column: 5, byte_offset: 60} + error = Error.expected_token(:delimiter, :rbrace, :keyword, :end, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'rbrace'") + assert String.contains?(error.message, "but got 'end'") + assert String.contains?(error.suggestion, "Add '}' to close the table constructor") + end + + test "creates error for expected assignment operator" do + position = %{line: 2, column: 7, byte_offset: 15} + error = Error.expected_token(:operator, :assign, :number, 42, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected operator 'assign'") + assert String.contains?(error.message, "but got number 42") + assert String.contains?(error.suggestion, "Add '=' for assignment") + end + + test "creates error for identifier when got keyword" do + position = %{line: 8, column: 10, byte_offset: 80} + error = Error.expected_token(:identifier, nil, :keyword, :if, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected identifier") + assert String.contains?(error.message, "but got 'if'") + assert String.contains?(error.suggestion, "Cannot use Lua keyword as identifier") + end + + test "creates error with no specific suggestion" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:operator, :plus, :operator, :minus, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected operator 'plus'") + assert String.contains?(error.message, "but got operator 'minus'") + assert error.suggestion == nil + end + end + + describe "unclosed_delimiter/3" do + test "creates error for unclosed lparen" do + position = %{line: 3, column: 5, byte_offset: 25} + error = Error.unclosed_delimiter(:lparen, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed opening parenthesis '('") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing ')'") + assert String.contains?(error.suggestion, "line 3") + end + + test "creates error for unclosed lbracket" do + position = %{line: 5, column: 2, byte_offset: 50} + error = Error.unclosed_delimiter(:lbracket, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed opening bracket '['") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing ']'") + assert String.contains?(error.suggestion, "line 5") + end + + test "creates error for unclosed 
lbrace" do + position = %{line: 7, column: 10, byte_offset: 75} + error = Error.unclosed_delimiter(:lbrace, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed opening brace '{'") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing '}'") + assert String.contains?(error.suggestion, "line 7") + end + + test "creates error for unclosed function block" do + position = %{line: 10, column: 1, byte_offset: 100} + error = Error.unclosed_delimiter(:function, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'function' block") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 10") + end + + test "creates error for unclosed if statement" do + position = %{line: 12, column: 1, byte_offset: 120} + error = Error.unclosed_delimiter(:if, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'if' statement") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 12") + end + + test "creates error for unclosed while loop" do + position = %{line: 15, column: 1, byte_offset: 150} + error = Error.unclosed_delimiter(:while, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'while' loop") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 15") + end + + test "creates error for unclosed for loop" do + position = %{line: 18, column: 1, byte_offset: 180} + error = Error.unclosed_delimiter(:for, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'for' loop") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 18") + end + + test "creates error for unclosed do block" do + position = %{line: 20, column: 1, byte_offset: 200} + error = Error.unclosed_delimiter(:do, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'do' block") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 20") + end + + test "creates error for unknown delimiter (default case)" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.unclosed_delimiter(:unknown_delim, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed unknown_delim") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing matching delimiter") + assert String.contains?(error.suggestion, "line 1") + end + + test "creates error with close position different from open position" do + open_pos = %{line: 3, column: 5, byte_offset: 25} + close_pos = %{line: 10, column: 1, byte_offset: 100} + error = Error.unclosed_delimiter(:lparen, open_pos, close_pos) + + assert error.type == :unclosed_delimiter + assert error.position == close_pos + assert String.contains?(error.suggestion, "line 3") + end + end + + describe "unexpected_end/2" do + test "creates error with position" do + position = %{line: 5, column: 1, byte_offset: 50} + 
error = Error.unexpected_end("function body", position) + + assert error.type == :unexpected_end + assert String.contains?(error.message, "Unexpected end of input") + assert String.contains?(error.message, "while parsing function body") + assert error.position == position + + assert String.contains?( + error.suggestion, + "Check for missing closing delimiters or keywords like 'end'" + ) + end + + test "creates error without position" do + error = Error.unexpected_end("expression") + + assert error.type == :unexpected_end + assert String.contains?(error.message, "Unexpected end of input") + assert String.contains?(error.message, "while parsing expression") + assert error.position == nil + assert error.suggestion != nil + end + end + + describe "format_token/2 (via expected_token)" do + test "formats keyword token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:keyword, :if, :keyword, :while, position) + + assert String.contains?(error.message, "'if'") + assert String.contains?(error.message, "'while'") + end + + test "formats identifier token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:identifier, "foo", :identifier, "bar", position) + + assert String.contains?(error.message, "identifier 'foo'") + assert String.contains?(error.message, "identifier 'bar'") + end + + test "formats number token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:number, 123, :number, 456, position) + + assert String.contains?(error.message, "number 123") + assert String.contains?(error.message, "number 456") + end + + test "formats string token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:string, "hello", :string, "world", position) + + assert String.contains?(error.message, "string \"hello\"") + assert String.contains?(error.message, "string \"world\"") + end + + test "formats operator token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:operator, :plus, :operator, :minus, position) + + assert String.contains?(error.message, "operator 'plus'") + assert String.contains?(error.message, "operator 'minus'") + end + + test "formats delimiter token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:delimiter, "(", :delimiter, ")", position) + + assert String.contains?(error.message, "'('") + assert String.contains?(error.message, "')'") + end + + test "formats eof token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:keyword, :end, :eof, nil, position) + + assert String.contains?(error.message, "end of input") + end + + test "formats unknown token type (default case)" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:unknown_type, "value", :other_type, "value2", position) + + assert String.contains?(error.message, "unknown_type") + assert String.contains?(error.message, "other_type") + end + end + + describe "format/2" do + test "formats error with position and context" do + source_code = """ + local x = 1 + local y = 2 + + local z = 3 + """ + + position = %{line: 2, column: 14, byte_offset: 26} + error = Error.new(:unexpected_token, "Unexpected token", position) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Parse Error") + assert String.contains?(formatted, "at line 2, column 14") + assert String.contains?(formatted, "Unexpected token") + assert 
String.contains?(formatted, "local x = 1") + assert String.contains?(formatted, "local y = 2 +") + assert String.contains?(formatted, "local z = 3") + assert String.contains?(formatted, "^") + end + + test "formats error without position" do + source_code = "local x = 1" + error = Error.new(:unexpected_end, "Unexpected end") + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Parse Error") + assert String.contains?(formatted, "(no position information)") + assert String.contains?(formatted, "Unexpected end") + refute String.contains?(formatted, "^") + end + + test "formats error with suggestion" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + + error = + Error.new(:expected_token, "Expected something", position, + suggestion: "Try adding a semicolon" + ) + + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Suggestion:") + assert String.contains?(formatted, "Try adding a semicolon") + end + + test "formats error without suggestion" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + error = Error.new(:invalid_syntax, "Invalid syntax", position) + formatted = Error.format(error, source_code) + + refute String.contains?(formatted, "Suggestion:") + end + + test "formats error with related errors" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + related_pos = %{line: 1, column: 10, byte_offset: 5} + + related = [Error.new(:unexpected_token, "Related issue", related_pos)] + error = Error.new(:expected_token, "Main error", position, related: related) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Related errors:") + assert String.contains?(formatted, "Related issue") + end + + test "formats error without related errors" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + error = Error.new(:unexpected_token, "Simple error", position) + formatted = Error.format(error, source_code) + + refute String.contains?(formatted, "Related errors:") + end + + test "formats error at start of file (line 1)" do + source_code = """ + function foo() + x = 1 + end + """ + + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.new(:unexpected_token, "Error at start", position) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "at line 1, column 1") + assert String.contains?(formatted, "function foo()") + assert String.contains?(formatted, "x = 1") + end + + test "formats error at end of file" do + source_code = """ + line 1 + line 2 + line 3 + line 4 + line 5 + """ + + position = %{line: 5, column: 6, byte_offset: 30} + error = Error.new(:unexpected_end, "Error at end", position) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "at line 5, column 6") + assert String.contains?(formatted, "line 3") + assert String.contains?(formatted, "line 4") + assert String.contains?(formatted, "line 5") + end + + test "formats error with empty source code" do + source_code = "" + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.new(:unexpected_end, "Empty file", position) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Parse Error") + assert String.contains?(formatted, "Empty file") + # Even with empty source, it will show context if position exists + # The format_context function will handle empty lines gracefully + end + + test "formats error 
showing context lines" do + source_code = """ + line 1 + line 2 + line 3 + line 4 + line 5 + line 6 + line 7 + """ + + position = %{line: 4, column: 3, byte_offset: 21} + error = Error.new(:unexpected_token, "Error in middle", position) + formatted = Error.format(error, source_code) + + # Should show 2 lines before and after (lines 2, 3, 4, 5, 6) + assert String.contains?(formatted, "line 2") + assert String.contains?(formatted, "line 3") + assert String.contains?(formatted, "line 4") + assert String.contains?(formatted, "line 5") + assert String.contains?(formatted, "line 6") + refute String.contains?(formatted, "line 1\n") + refute String.contains?(formatted, "line 7\n") + end + + test "formats error with pointer at correct column" do + source_code = "local x = 1 + 2" + position = %{line: 1, column: 15, byte_offset: 14} + error = Error.new(:unexpected_token, "Error", position) + formatted = Error.format(error, source_code) + + lines = String.split(formatted, "\n") + error_line_idx = Enum.find_index(lines, &String.contains?(&1, "local x = 1 + 2")) + pointer_line = Enum.at(lines, error_line_idx + 1) + + # Pointer should be at column 15 + assert String.contains?(pointer_line, "^") + # Count spaces before ^ + spaces_before = pointer_line |> String.split("^") |> hd() |> String.length() + # Should account for line number formatting (4 digits + " │ " = 7 chars) + column - 1 + assert spaces_before >= 20 + end + end + + describe "format_multiple/2" do + test "formats single error" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + error = Error.new(:unexpected_token, "Error 1", position) + formatted = Error.format_multiple([error], source_code) + + assert String.contains?(formatted, "Found 1 parse error") + refute String.contains?(formatted, "1 parse errors") + assert String.contains?(formatted, "Error 1:") + assert String.contains?(formatted, "Error 1") + end + + test "formats multiple errors" do + source_code = """ + local x = 1 + local y = 2 + local z = 3 + """ + + position1 = %{line: 1, column: 5, byte_offset: 0} + position2 = %{line: 2, column: 5, byte_offset: 12} + + error1 = Error.new(:unexpected_token, "First error", position1) + error2 = Error.new(:expected_token, "Second error", position2) + + formatted = Error.format_multiple([error1, error2], source_code) + + assert String.contains?(formatted, "Found 2 parse errors") + assert String.contains?(formatted, "Error 1:") + assert String.contains?(formatted, "First error") + assert String.contains?(formatted, "Error 2:") + assert String.contains?(formatted, "Second error") + end + + test "formats three errors" do + source_code = "x" + position = %{line: 1, column: 1, byte_offset: 0} + + error1 = Error.new(:unexpected_token, "E1", position) + error2 = Error.new(:expected_token, "E2", position) + error3 = Error.new(:invalid_syntax, "E3", position) + + formatted = Error.format_multiple([error1, error2, error3], source_code) + + assert String.contains?(formatted, "Found 3 parse errors") + assert String.contains?(formatted, "Error 1:") + assert String.contains?(formatted, "Error 2:") + assert String.contains?(formatted, "Error 3:") + end + end +end diff --git a/test/lua/parser/pratt_test.exs b/test/lua/parser/pratt_test.exs new file mode 100644 index 0000000..f5a2715 --- /dev/null +++ b/test/lua/parser/pratt_test.exs @@ -0,0 +1,253 @@ +defmodule Lua.Parser.PrattTest do + use ExUnit.Case, async: true + alias Lua.Parser.Pratt + + describe "binding_power/1" do + test "returns correct binding power for or operator" 
do + assert {1, 2} = Pratt.binding_power(:or) + end + + test "returns correct binding power for and operator" do + assert {3, 4} = Pratt.binding_power(:and) + end + + test "returns correct binding power for comparison operators" do + assert {5, 6} = Pratt.binding_power(:lt) + assert {5, 6} = Pratt.binding_power(:gt) + assert {5, 6} = Pratt.binding_power(:le) + assert {5, 6} = Pratt.binding_power(:ge) + assert {5, 6} = Pratt.binding_power(:ne) + assert {5, 6} = Pratt.binding_power(:eq) + end + + test "returns correct binding power for concat operator (right associative)" do + assert {7, 6} = Pratt.binding_power(:concat) + end + + test "returns correct binding power for additive operators" do + assert {9, 10} = Pratt.binding_power(:add) + assert {9, 10} = Pratt.binding_power(:sub) + end + + test "returns correct binding power for multiplicative operators" do + assert {11, 12} = Pratt.binding_power(:mul) + assert {11, 12} = Pratt.binding_power(:div) + assert {11, 12} = Pratt.binding_power(:floordiv) + assert {11, 12} = Pratt.binding_power(:mod) + end + + test "returns correct binding power for unary operators (should not be used as binary)" do + # These are unary operators, but the function includes them for completeness + # They should not be used as binary operators in practice + assert {13, 14} = Pratt.binding_power(:not) + assert {13, 14} = Pratt.binding_power(:neg) + assert {13, 14} = Pratt.binding_power(:len) + end + + test "returns correct binding power for power operator (right associative)" do + assert {16, 15} = Pratt.binding_power(:pow) + end + + test "returns nil for non-operators" do + assert nil == Pratt.binding_power(:invalid) + assert nil == Pratt.binding_power(:lparen) + assert nil == Pratt.binding_power(:rparen) + assert nil == Pratt.binding_power(:identifier) + end + end + + describe "prefix_binding_power/1" do + test "returns correct binding power for not operator" do + assert 14 = Pratt.prefix_binding_power(:not) + end + + test "returns correct binding power for unary minus (sub)" do + assert 13 = Pratt.prefix_binding_power(:sub) + end + + test "returns correct binding power for length operator" do + assert 14 = Pratt.prefix_binding_power(:len) + end + + test "returns nil for non-prefix operators" do + assert nil == Pratt.prefix_binding_power(:add) + assert nil == Pratt.prefix_binding_power(:mul) + assert nil == Pratt.prefix_binding_power(:invalid) + end + end + + describe "token_to_binop/1" do + test "maps logical operators" do + assert :or = Pratt.token_to_binop(:or) + assert :and = Pratt.token_to_binop(:and) + end + + test "maps comparison operators" do + assert :lt = Pratt.token_to_binop(:lt) + assert :gt = Pratt.token_to_binop(:gt) + assert :le = Pratt.token_to_binop(:le) + assert :ge = Pratt.token_to_binop(:ge) + assert :ne = Pratt.token_to_binop(:ne) + assert :eq = Pratt.token_to_binop(:eq) + end + + test "maps string concatenation operator" do + assert :concat = Pratt.token_to_binop(:concat) + end + + test "maps arithmetic operators" do + assert :add = Pratt.token_to_binop(:add) + assert :sub = Pratt.token_to_binop(:sub) + assert :mul = Pratt.token_to_binop(:mul) + assert :div = Pratt.token_to_binop(:div) + assert :floordiv = Pratt.token_to_binop(:floordiv) + assert :mod = Pratt.token_to_binop(:mod) + assert :pow = Pratt.token_to_binop(:pow) + end + + test "returns nil for non-binary operators" do + assert nil == Pratt.token_to_binop(:not) + assert nil == Pratt.token_to_binop(:len) + assert nil == Pratt.token_to_binop(:invalid) + end + end + + describe 
"token_to_unop/1" do + test "maps unary operators" do + assert :not = Pratt.token_to_unop(:not) + assert :neg = Pratt.token_to_unop(:sub) + assert :len = Pratt.token_to_unop(:len) + end + + test "returns nil for non-unary operators" do + assert nil == Pratt.token_to_unop(:add) + assert nil == Pratt.token_to_unop(:mul) + assert nil == Pratt.token_to_unop(:or) + assert nil == Pratt.token_to_unop(:invalid) + end + end + + describe "is_binary_op?/1" do + test "returns true for binary operators" do + assert Pratt.is_binary_op?(:or) + assert Pratt.is_binary_op?(:and) + assert Pratt.is_binary_op?(:lt) + assert Pratt.is_binary_op?(:gt) + assert Pratt.is_binary_op?(:le) + assert Pratt.is_binary_op?(:ge) + assert Pratt.is_binary_op?(:ne) + assert Pratt.is_binary_op?(:eq) + assert Pratt.is_binary_op?(:concat) + assert Pratt.is_binary_op?(:add) + assert Pratt.is_binary_op?(:sub) + assert Pratt.is_binary_op?(:mul) + assert Pratt.is_binary_op?(:div) + assert Pratt.is_binary_op?(:floordiv) + assert Pratt.is_binary_op?(:mod) + assert Pratt.is_binary_op?(:pow) + end + + test "returns false for non-binary operators" do + refute Pratt.is_binary_op?(:invalid) + refute Pratt.is_binary_op?(:lparen) + refute Pratt.is_binary_op?(:identifier) + end + end + + describe "is_prefix_op?/1" do + test "returns true for prefix operators" do + assert Pratt.is_prefix_op?(:not) + assert Pratt.is_prefix_op?(:sub) + assert Pratt.is_prefix_op?(:len) + end + + test "returns false for non-prefix operators" do + refute Pratt.is_prefix_op?(:add) + refute Pratt.is_prefix_op?(:mul) + refute Pratt.is_prefix_op?(:or) + refute Pratt.is_prefix_op?(:invalid) + end + end + + describe "operator precedence correctness" do + test "logical operators have lowest precedence" do + {or_left, _} = Pratt.binding_power(:or) + {and_left, _} = Pratt.binding_power(:and) + {lt_left, _} = Pratt.binding_power(:lt) + + assert or_left < and_left + assert and_left < lt_left + end + + test "comparison operators have same precedence" do + {lt_left, _} = Pratt.binding_power(:lt) + {gt_left, _} = Pratt.binding_power(:gt) + {le_left, _} = Pratt.binding_power(:le) + {ge_left, _} = Pratt.binding_power(:ge) + {ne_left, _} = Pratt.binding_power(:ne) + {eq_left, _} = Pratt.binding_power(:eq) + + assert lt_left == gt_left + assert lt_left == le_left + assert lt_left == ge_left + assert lt_left == ne_left + assert lt_left == eq_left + end + + test "concat is right associative (left_bp > right_bp)" do + {left_bp, right_bp} = Pratt.binding_power(:concat) + assert left_bp > right_bp + end + + test "power is right associative (left_bp > right_bp)" do + {left_bp, right_bp} = Pratt.binding_power(:pow) + assert left_bp > right_bp + end + + test "addition is left associative (left_bp >= right_bp)" do + {left_bp, right_bp} = Pratt.binding_power(:add) + assert left_bp < right_bp + end + + test "multiplication has higher precedence than addition" do + {add_left, _} = Pratt.binding_power(:add) + {mul_left, _} = Pratt.binding_power(:mul) + + assert mul_left > add_left + end + + test "all multiplicative operators have same precedence" do + {mul_left, _} = Pratt.binding_power(:mul) + {div_left, _} = Pratt.binding_power(:div) + {floordiv_left, _} = Pratt.binding_power(:floordiv) + {mod_left, _} = Pratt.binding_power(:mod) + + assert mul_left == div_left + assert mul_left == floordiv_left + assert mul_left == mod_left + end + + test "unary operators have higher precedence than multiplication" do + unary_bp = Pratt.prefix_binding_power(:sub) + {mul_left, _} = Pratt.binding_power(:mul) + + 
assert unary_bp > mul_left + end + + test "power has higher precedence than unary (special case)" do + unary_bp = Pratt.prefix_binding_power(:sub) + {pow_left, _} = Pratt.binding_power(:pow) + + # Unary minus has lower precedence than power's left binding + # This ensures -2^3 parses as -(2^3) + assert unary_bp < pow_left + end + + test "not and len have same prefix binding power" do + not_bp = Pratt.prefix_binding_power(:not) + len_bp = Pratt.prefix_binding_power(:len) + + assert not_bp == len_bp + end + end +end diff --git a/test/lua/parser/precedence_test.exs b/test/lua/parser/precedence_test.exs index 08ddf37..ba9e925 100644 --- a/test/lua/parser/precedence_test.exs +++ b/test/lua/parser/precedence_test.exs @@ -309,7 +309,8 @@ defmodule Lua.Parser.PrecedenceTest do test "all operators with correct precedence" do # a or b and c < d .. e + f * g ^ h # Should parse as: a or (b and (c < (d .. (e + (f * (g ^ h)))))) - assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a or b and c < d .. e + f * g ^ h") + assert {:ok, %{block: %{stmts: [stmt]}}} = + Parser.parse("return a or b and c < d .. e + f * g ^ h") assert %{ values: [ diff --git a/test/lua/parser/recovery_test.exs b/test/lua/parser/recovery_test.exs index 7f44314..2ab5914 100644 --- a/test/lua/parser/recovery_test.exs +++ b/test/lua/parser/recovery_test.exs @@ -65,6 +65,18 @@ defmodule Lua.Parser.RecoveryTest do assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error) assert [{:eof, _}] = rest end + + test "fails when no boundary found" do + tokens = [ + {:identifier, "x", %{line: 1, column: 1}}, + {:operator, :assign, %{line: 1, column: 3}}, + {:number, 42, %{line: 1, column: 5}} + ] + + error = Error.new(:unexpected_token, "Test", %{line: 1, column: 1}) + + assert {:failed, [^error]} = Recovery.recover_at_statement(tokens, error) + end end describe "recover_unclosed_delimiter/3" do @@ -76,7 +88,9 @@ defmodule Lua.Parser.RecoveryTest do error = Error.new(:unclosed_delimiter, "Unclosed (", %{line: 1, column: 1}) - assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert [{:eof, _}] = rest end @@ -88,7 +102,9 @@ defmodule Lua.Parser.RecoveryTest do error = Error.new(:unclosed_delimiter, "Unclosed [", %{line: 1, column: 1}) - assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lbracket, error) + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :lbracket, error) + assert [{:eof, _}] = rest end @@ -100,7 +116,9 @@ defmodule Lua.Parser.RecoveryTest do error = Error.new(:unclosed_delimiter, "Unclosed {", %{line: 1, column: 1}) - assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lbrace, error) + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :lbrace, error) + assert [{:eof, _}] = rest end @@ -114,7 +132,9 @@ defmodule Lua.Parser.RecoveryTest do error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1}) - assert {:recovered, rest, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert [{:eof, _}] = rest end @@ -126,9 +146,82 @@ defmodule Lua.Parser.RecoveryTest do error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1}) - assert {:recovered, rest, [^error]} = 
Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + assert [{:keyword, :end, _} | _] = rest end + + test "finds closing end keyword at depth 1" do + tokens = [ + {:keyword, :end, %{line: 1, column: 10}}, + {:eof, %{line: 1, column: 13}} + ] + + error = Error.new(:unclosed_delimiter, "Unclosed function", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :function, error) + + assert [{:eof, _}] = rest + end + + test "handles nested keywords with end" do + # Simulating: function ... if ... end end + tokens = [ + {:keyword, :if, %{line: 1, column: 2}}, + {:keyword, :end, %{line: 1, column: 7}}, + {:keyword, :end, %{line: 1, column: 11}}, + {:eof, %{line: 1, column: 14}} + ] + + error = Error.new(:unclosed_delimiter, "Unclosed function", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :function, error) + + assert [{:eof, _}] = rest + end + + test "skips non-delimiter tokens when searching" do + tokens = [ + {:identifier, "x", %{line: 1, column: 2}}, + {:operator, :add, %{line: 1, column: 4}}, + {:number, 42, %{line: 1, column: 6}}, + {:delimiter, :rparen, %{line: 1, column: 8}}, + {:eof, %{line: 1, column: 9}} + ] + + error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + + assert [{:eof, _}] = rest + end + + test "handles empty token list" do + tokens = [] + + error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1}) + + assert {:failed, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lparen, error) + end + + test "uses closing_delimiter with catch-all for keywords" do + # Test that non-standard delimiters fall through to :end + tokens = [ + {:keyword, :end, %{line: 1, column: 10}}, + {:eof, %{line: 1, column: 13}} + ] + + error = Error.new(:unclosed_delimiter, "Unclosed if", %{line: 1, column: 1}) + + assert {:recovered, rest, [^error]} = + Recovery.recover_unclosed_delimiter(tokens, :if, error) + + assert [{:eof, _}] = rest + end end describe "recover_missing_keyword/3" do @@ -156,6 +249,24 @@ defmodule Lua.Parser.RecoveryTest do assert {:recovered, rest, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error) assert [{:keyword, :end, _} | _] = rest end + + test "handles empty token list" do + tokens = [] + + error = Error.new(:expected_token, "Expected then", %{line: 1, column: 1}) + + assert {:failed, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error) + end + + test "handles EOF when searching for keyword" do + tokens = [{:eof, %{line: 1, column: 1}}] + + error = Error.new(:expected_token, "Expected then", %{line: 1, column: 1}) + + # EOF is a statement boundary, so it should recover + assert {:recovered, rest, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error) + assert [{:eof, _}] = rest + end end describe "skip_to_statement/1" do diff --git a/test/lua/runtime_exception_test.exs b/test/lua/runtime_exception_test.exs new file mode 100644 index 0000000..d6162e3 --- /dev/null +++ b/test/lua/runtime_exception_test.exs @@ -0,0 +1,434 @@ +defmodule Lua.RuntimeExceptionTest do + use ExUnit.Case, async: true + alias Lua.RuntimeException + + describe "exception/1 with {:lua_error, error, state}" do + test "formats simple lua error with stacktrace" do + lua = Lua.new() + + # Create a 
lua error by dividing by zero + assert_raise RuntimeException, fn -> + Lua.eval!(lua, "return 1 / 0 * 'string'") + end + end + + test "includes formatted error message and stacktrace" do + lua = Lua.new() + + exception = + try do + Lua.eval!(lua, "error('custom error message')") + rescue + e in RuntimeException -> e + end + + assert exception.message =~ "Lua runtime error:" + assert exception.message =~ "custom error message" + assert exception.state != nil + assert exception.original != nil + end + + test "handles parse errors" do + lua = Lua.new() + + exception = + try do + # Create a parse error by using illegal token + Lua.eval!(lua, "local x = \x01") + rescue + e in Lua.CompilerException -> e + end + + # Parse errors are CompilerException, not RuntimeException + assert exception.__struct__ == Lua.CompilerException + end + + test "handles badarith errors" do + lua = Lua.new() + + exception = + try do + Lua.eval!(lua, "return 'string' + 5") + rescue + e in RuntimeException -> e + end + + assert exception.message =~ "Lua runtime error:" + assert exception.state != nil + end + + test "handles illegal index errors" do + lua = Lua.new() + + exception = + try do + # Attempt to index a non-table value + Lua.eval!(lua, "local x = 5; return x.foo") + rescue + e in RuntimeException -> e + end + + assert exception.message =~ "Lua runtime error:" + assert exception.state != nil + end + end + + describe "exception/1 with {:api_error, details, state}" do + test "creates exception with api error message" do + state = :luerl.init() + details = "invalid function call" + + exception = RuntimeException.exception({:api_error, details, state}) + + assert exception.message == "Lua API error: invalid function call" + assert exception.original == details + assert exception.state == state + end + + test "handles complex api error details" do + state = :luerl.init() + details = "function returned invalid type: expected table, got nil" + + exception = RuntimeException.exception({:api_error, details, state}) + + assert exception.message == + "Lua API error: function returned invalid type: expected table, got nil" + + assert exception.original == details + assert exception.state == state + end + end + + describe "exception/1 with keyword list [scope:, function:, message:]" do + test "formats error with empty scope" do + exception = + RuntimeException.exception( + scope: [], + function: "my_function", + message: "invalid arguments" + ) + + assert exception.message == "Lua runtime error: my_function() failed, invalid arguments" + + assert exception.original == [ + scope: [], + function: "my_function", + message: "invalid arguments" + ] + + assert exception.state == nil + end + + test "formats error with single scope element" do + exception = + RuntimeException.exception( + scope: ["math"], + function: "sqrt", + message: "negative number not allowed" + ) + + assert exception.message == + "Lua runtime error: math.sqrt() failed, negative number not allowed" + + assert exception.original == [ + scope: ["math"], + function: "sqrt", + message: "negative number not allowed" + ] + + assert exception.state == nil + end + + test "formats error with multiple scope elements" do + exception = + RuntimeException.exception( + scope: ["my", "module", "nested"], + function: "process", + message: "data validation failed" + ) + + assert exception.message == + "Lua runtime error: my.module.nested.process() failed, data validation failed" + + assert exception.original == [ + scope: ["my", "module", "nested"], + function: "process", + 
message: "data validation failed" + ] + + assert exception.state == nil + end + + test "raises when scope key is missing" do + assert_raise KeyError, fn -> + RuntimeException.exception( + function: "my_function", + message: "invalid arguments" + ) + end + end + + test "raises when function key is missing" do + assert_raise KeyError, fn -> + RuntimeException.exception( + scope: [], + message: "invalid arguments" + ) + end + end + + test "raises when message key is missing" do + assert_raise KeyError, fn -> + RuntimeException.exception( + scope: [], + function: "my_function" + ) + end + end + end + + describe "exception/1 with binary string" do + test "formats simple binary error" do + exception = RuntimeException.exception("something went wrong") + + assert exception.message == "Lua runtime error: something went wrong" + assert exception.original == nil + assert exception.state == nil + end + + test "trims whitespace from binary error" do + exception = RuntimeException.exception(" error with spaces \n") + + assert exception.message == "Lua runtime error: error with spaces" + assert exception.original == nil + assert exception.state == nil + end + + test "handles empty binary string" do + exception = RuntimeException.exception("") + + assert exception.message == "Lua runtime error: " + assert exception.original == nil + assert exception.state == nil + end + + test "handles multi-line binary error" do + error_message = """ + multi-line error + with details + """ + + exception = RuntimeException.exception(error_message) + + assert exception.message == "Lua runtime error: multi-line error\nwith details" + assert exception.original == nil + assert exception.state == nil + end + end + + describe "exception/1 with generic error (fallback clause)" do + test "handles built-in exception types" do + error = ArgumentError.exception("invalid argument") + + exception = RuntimeException.exception(error) + + assert exception.message == "Lua runtime error: invalid argument" + assert exception.original == error + assert exception.state == nil + end + + test "handles RuntimeError exception" do + error = RuntimeError.exception("runtime failure") + + exception = RuntimeException.exception(error) + + assert exception.message == "Lua runtime error: runtime failure" + assert exception.original == error + assert exception.state == nil + end + + test "handles KeyError exception" do + error = KeyError.exception(key: :missing, term: %{}) + + exception = RuntimeException.exception(error) + + assert exception.message =~ "Lua runtime error:" + assert exception.message =~ "key :missing not found" + assert exception.original == error + assert exception.state == nil + end + + test "handles non-exception atom" do + exception = RuntimeException.exception(:some_error) + + assert exception.message == "Lua runtime error: :some_error" + assert exception.original == :some_error + assert exception.state == nil + end + + test "handles non-exception integer" do + exception = RuntimeException.exception(42) + + assert exception.message == "Lua runtime error: 42" + assert exception.original == 42 + assert exception.state == nil + end + + test "handles non-exception tuple" do + error = {:error, :not_found} + + exception = RuntimeException.exception(error) + + assert exception.message == "Lua runtime error: {:error, :not_found}" + assert exception.original == error + assert exception.state == nil + end + + test "handles non-exception map" do + error = %{code: 404, message: "not found"} + + exception = RuntimeException.exception(error) + + 
assert exception.message == "Lua runtime error: %{code: 404, message: \"not found\"}" + assert exception.original == error + assert exception.state == nil + end + + test "handles non-exception list (not keyword list)" do + # A list with integer keys won't match the keyword list pattern + # because Keyword.fetch!/2 expects atom keys + error = [1, 2, 3] + + # This will raise KeyError because the keyword list clause + # will try to match first and fail on Keyword.fetch!/2 + assert_raise KeyError, fn -> + RuntimeException.exception(error) + end + end + + test "handles UndefinedFunctionError" do + error = UndefinedFunctionError.exception(module: MyModule, function: :my_func, arity: 2) + + exception = RuntimeException.exception(error) + + assert exception.message =~ "Lua runtime error:" + assert exception.message =~ "MyModule.my_func/2" + assert exception.original == error + assert exception.state == nil + end + end + + describe "format_function/2 (private function tested via keyword list exception)" do + test "formats function with empty scope" do + exception = + RuntimeException.exception( + scope: [], + function: "test", + message: "error" + ) + + assert exception.message =~ "test() failed" + end + + test "formats function with single element scope" do + exception = + RuntimeException.exception( + scope: ["module"], + function: "func", + message: "error" + ) + + assert exception.message =~ "module.func() failed" + end + + test "formats function with nested scope" do + exception = + RuntimeException.exception( + scope: ["a", "b", "c"], + function: "method", + message: "error" + ) + + assert exception.message =~ "a.b.c.method() failed" + end + end + + describe "exception message format" do + test "RuntimeException implements Exception protocol" do + exception = RuntimeException.exception("test error") + + assert Exception.message(exception) == "Lua runtime error: test error" + end + + test "can be raised with raise/2" do + assert_raise RuntimeException, "Lua runtime error: test", fn -> + raise RuntimeException, "test" + end + end + + test "can be raised with keyword list" do + assert_raise RuntimeException, fn -> + raise RuntimeException, + scope: ["my", "module"], + function: "test", + message: "failed" + end + end + + test "preserves original error information" do + original = {:error, :custom_reason} + exception = RuntimeException.exception(original) + + assert exception.original == original + assert exception.state == nil + end + end + + describe "integration with Lua module" do + test "RuntimeException is raised for runtime errors" do + lua = Lua.new() + + assert_raise RuntimeException, fn -> + Lua.eval!(lua, "error('test error')") + end + end + + test "RuntimeException is raised for type errors" do + lua = Lua.new() + + assert_raise RuntimeException, fn -> + Lua.eval!(lua, "return 'string' + 5") + end + end + + test "RuntimeException is raised for sandboxed functions" do + lua = Lua.new(sandboxed: [[:os, :exit]]) + + assert_raise RuntimeException, fn -> + Lua.eval!(lua, "os.exit()") + end + end + + test "RuntimeException is raised for empty keys in set!" 
do + lua = Lua.new() + + assert_raise RuntimeException, "Lua runtime error: Lua.set!/3 cannot have empty keys", fn -> + Lua.set!(lua, [], "value") + end + end + + test "RuntimeException is raised when deflua returns non-encoded data" do + lua = Lua.new() + + lua = + Lua.set!(lua, [:test_func], fn _args -> + # Return non-encoded atom (not a valid Lua value) + [:invalid_atom] + end) + + assert_raise RuntimeException, fn -> + Lua.eval!(lua, "test_func()") + end + end + end +end From c86012bc2275d6b8bb657017dc4ce00703be1f6d Mon Sep 17 00:00:00 2001 From: Dave Lucia Date: Fri, 6 Feb 2026 08:33:49 -0500 Subject: [PATCH 4/4] fix CI --- lib/lua/ast/block.ex | 6 +- lib/lua/ast/builder.ex | 72 ++++++------ lib/lua/ast/expr.ex | 32 +++--- lib/lua/ast/pretty_printer.ex | 38 +++--- lib/lua/ast/{stmt.ex => statement.ex} | 36 +++--- lib/lua/ast/walker.ex | 59 +++++----- lib/lua/parser.ex | 58 +++++----- lib/lua/parser/recovery.ex | 1 + test/lua/ast/builder_test.exs | 68 +++++------ test/lua/ast/pretty_printer_test.exs | 4 +- test/lua/ast/walker_test.exs | 34 +++--- test/lua/parser/error_test.exs | 12 +- test/lua/parser/expr_test.exs | 6 +- .../{stmt_test.exs => statement_test.exs} | 108 +++++++++--------- 14 files changed, 277 insertions(+), 257 deletions(-) rename lib/lua/ast/{stmt.ex => statement.ex} (99%) rename test/lua/parser/{stmt_test.exs => statement_test.exs} (81%) diff --git a/lib/lua/ast/block.ex b/lib/lua/ast/block.ex index 1ab1db6..6a0fc25 100644 --- a/lib/lua/ast/block.ex +++ b/lib/lua/ast/block.ex @@ -6,10 +6,10 @@ defmodule Lua.AST.Block do Blocks create a new scope for local variables. """ - alias Lua.AST.{Meta, Stmt} + alias Lua.AST.{Meta, Statement} @type t :: %__MODULE__{ - stmts: [Stmt.t()], + stmts: [Statement.t()], meta: Meta.t() | nil } @@ -26,7 +26,7 @@ defmodule Lua.AST.Block do iex> Lua.AST.Block.new([], %Lua.AST.Meta{}) %Lua.AST.Block{stmts: [], meta: %Lua.AST.Meta{start: nil, end: nil, metadata: %{}}} """ - @spec new([Stmt.t()], Meta.t() | nil) :: t() + @spec new([Statement.t()], Meta.t() | nil) :: t() def new(stmts \\ [], meta \\ nil) do %__MODULE__{stmts: stmts, meta: meta} end diff --git a/lib/lua/ast/builder.ex b/lib/lua/ast/builder.ex index 4d53ffc..221fccb 100644 --- a/lib/lua/ast/builder.ex +++ b/lib/lua/ast/builder.ex @@ -25,7 +25,7 @@ defmodule Lua.AST.Builder do ]) """ - alias Lua.AST.{Chunk, Block, Meta, Expr, Stmt} + alias Lua.AST.{Chunk, Block, Meta, Expr, Statement} # Chunk and Block @@ -36,7 +36,7 @@ defmodule Lua.AST.Builder do chunk([local(["x"], [number(42)])]) """ - @spec chunk([Stmt.t()], Meta.t() | nil) :: Chunk.t() + @spec chunk([Statement.t()], Meta.t() | nil) :: Chunk.t() def chunk(stmts, meta \\ nil) do %Chunk{ block: block(stmts, meta), @@ -54,7 +54,7 @@ defmodule Lua.AST.Builder do assign([var("x")], [number(20)]) ]) """ - @spec block([Stmt.t()], Meta.t() | nil) :: Block.t() + @spec block([Statement.t()], Meta.t() | nil) :: Block.t() def block(stmts, meta \\ nil) do %Block{ stmts: stmts, @@ -261,7 +261,7 @@ defmodule Lua.AST.Builder do # function(...) return ... 
end function_expr([], [return_stmt([vararg()])], vararg: true) """ - @spec function_expr([String.t()], [Stmt.t()], keyword()) :: Expr.Function.t() + @spec function_expr([String.t()], [Statement.t()], keyword()) :: Expr.Function.t() def function_expr(params, body_stmts, opts \\ []) do params_with_vararg = if Keyword.get(opts, :vararg, false) do @@ -290,9 +290,9 @@ defmodule Lua.AST.Builder do # x, y = 1, 2 assign([var("x"), var("y")], [number(1), number(2)]) """ - @spec assign([Expr.t()], [Expr.t()], Meta.t() | nil) :: Stmt.Assign.t() + @spec assign([Expr.t()], [Expr.t()], Meta.t() | nil) :: Statement.Assign.t() def assign(targets, values, meta \\ nil) do - %Stmt.Assign{ + %Statement.Assign{ targets: targets, values: values, meta: meta @@ -313,9 +313,9 @@ defmodule Lua.AST.Builder do # local x, y = 1, 2 local(["x", "y"], [number(1), number(2)]) """ - @spec local([String.t()], [Expr.t()], Meta.t() | nil) :: Stmt.Local.t() + @spec local([String.t()], [Expr.t()], Meta.t() | nil) :: Statement.Local.t() def local(names, values \\ [], meta \\ nil) do - %Stmt.Local{ + %Statement.Local{ names: names, values: values, meta: meta @@ -332,7 +332,8 @@ defmodule Lua.AST.Builder do return_stmt([binop(:add, var("a"), var("b"))]) ]) """ - @spec local_func(String.t(), [String.t()], [Stmt.t()], keyword()) :: Stmt.LocalFunc.t() + @spec local_func(String.t(), [String.t()], [Statement.t()], keyword()) :: + Statement.LocalFunc.t() def local_func(name, params, body_stmts, opts \\ []) do params_with_vararg = if Keyword.get(opts, :vararg, false) do @@ -341,7 +342,7 @@ defmodule Lua.AST.Builder do params end - %Stmt.LocalFunc{ + %Statement.LocalFunc{ name: name, params: params_with_vararg, body: block(body_stmts), @@ -362,8 +363,8 @@ defmodule Lua.AST.Builder do # function math.add(a, b) return a + b end func_decl(["math", "add"], ["a", "b"], [...]) """ - @spec func_decl(String.t() | [String.t()], [String.t()], [Stmt.t()], keyword()) :: - Stmt.FuncDecl.t() + @spec func_decl(String.t() | [String.t()], [String.t()], [Statement.t()], keyword()) :: + Statement.FuncDecl.t() def func_decl(name, params, body_stmts, opts \\ []) do name_parts = if is_binary(name), do: [name], else: name @@ -376,7 +377,7 @@ defmodule Lua.AST.Builder do is_method = Keyword.get(opts, :is_method, false) - %Stmt.FuncDecl{ + %Statement.FuncDecl{ name: name_parts, params: params_with_vararg, body: block(body_stmts), @@ -392,9 +393,9 @@ defmodule Lua.AST.Builder do call_stmt(call(var("print"), [string("hello")])) """ - @spec call_stmt(Expr.Call.t() | Expr.MethodCall.t(), Meta.t() | nil) :: Stmt.CallStmt.t() + @spec call_stmt(Expr.Call.t() | Expr.MethodCall.t(), Meta.t() | nil) :: Statement.CallStmt.t() def call_stmt(call_expr, meta \\ nil) do - %Stmt.CallStmt{ + %Statement.CallStmt{ call: call_expr, meta: meta } @@ -419,9 +420,9 @@ defmodule Lua.AST.Builder do else: [call_stmt(...)] ) """ - @spec if_stmt(Expr.t(), [Stmt.t()], keyword()) :: Stmt.If.t() + @spec if_stmt(Expr.t(), [Statement.t()], keyword()) :: Statement.If.t() def if_stmt(condition, then_stmts, opts \\ []) do - %Stmt.If{ + %Statement.If{ condition: condition, then_block: block(then_stmts), elseifs: Keyword.get(opts, :elseif, []) |> Enum.map(fn {c, s} -> {c, block(s)} end), @@ -441,9 +442,9 @@ defmodule Lua.AST.Builder do [assign([var("x")], [binop(:sub, var("x"), number(1))])] ) """ - @spec while_stmt(Expr.t(), [Stmt.t()], Meta.t() | nil) :: Stmt.While.t() + @spec while_stmt(Expr.t(), [Statement.t()], Meta.t() | nil) :: Statement.While.t() def while_stmt(condition, body_stmts, meta \\ 
nil) do - %Stmt.While{ + %Statement.While{ condition: condition, body: block(body_stmts), meta: meta @@ -461,9 +462,9 @@ defmodule Lua.AST.Builder do binop(:le, var("x"), number(0)) ) """ - @spec repeat_stmt([Stmt.t()], Expr.t(), Meta.t() | nil) :: Stmt.Repeat.t() + @spec repeat_stmt([Statement.t()], Expr.t(), Meta.t() | nil) :: Statement.Repeat.t() def repeat_stmt(body_stmts, condition, meta \\ nil) do - %Stmt.Repeat{ + %Statement.Repeat{ body: block(body_stmts), condition: condition, meta: meta @@ -483,9 +484,10 @@ defmodule Lua.AST.Builder do # for i = 1, 10, 2 do print(i) end for_num("i", number(1), number(10), [...], step: number(2)) """ - @spec for_num(String.t(), Expr.t(), Expr.t(), [Stmt.t()], keyword()) :: Stmt.ForNum.t() + @spec for_num(String.t(), Expr.t(), Expr.t(), [Statement.t()], keyword()) :: + Statement.ForNum.t() def for_num(var_name, start, limit, body_stmts, opts \\ []) do - %Stmt.ForNum{ + %Statement.ForNum{ var: var_name, start: start, limit: limit, @@ -507,9 +509,9 @@ defmodule Lua.AST.Builder do [call_stmt(call(var("print"), [var("k"), var("v")]))] ) """ - @spec for_in([String.t()], [Expr.t()], [Stmt.t()], Meta.t() | nil) :: Stmt.ForIn.t() + @spec for_in([String.t()], [Expr.t()], [Statement.t()], Meta.t() | nil) :: Statement.ForIn.t() def for_in(vars, iterators, body_stmts, meta \\ nil) do - %Stmt.ForIn{ + %Statement.ForIn{ vars: vars, iterators: iterators, body: block(body_stmts), @@ -528,9 +530,9 @@ defmodule Lua.AST.Builder do call_stmt(call(var("print"), [var("x")])) ]) """ - @spec do_block([Stmt.t()], Meta.t() | nil) :: Stmt.Do.t() + @spec do_block([Statement.t()], Meta.t() | nil) :: Statement.Do.t() def do_block(body_stmts, meta \\ nil) do - %Stmt.Do{ + %Statement.Do{ body: block(body_stmts), meta: meta } @@ -550,23 +552,23 @@ defmodule Lua.AST.Builder do # return x, y return_stmt([var("x"), var("y")]) """ - @spec return_stmt([Expr.t()], Meta.t() | nil) :: Stmt.Return.t() + @spec return_stmt([Expr.t()], Meta.t() | nil) :: Statement.Return.t() def return_stmt(values, meta \\ nil) do - %Stmt.Return{ + %Statement.Return{ values: values, meta: meta } end @doc "Creates a break statement" - @spec break_stmt(Meta.t() | nil) :: Stmt.Break.t() - def break_stmt(meta \\ nil), do: %Stmt.Break{meta: meta} + @spec break_stmt(Meta.t() | nil) :: Statement.Break.t() + def break_stmt(meta \\ nil), do: %Statement.Break{meta: meta} @doc "Creates a goto statement" - @spec goto_stmt(String.t(), Meta.t() | nil) :: Stmt.Goto.t() - def goto_stmt(label, meta \\ nil), do: %Stmt.Goto{label: label, meta: meta} + @spec goto_stmt(String.t(), Meta.t() | nil) :: Statement.Goto.t() + def goto_stmt(label, meta \\ nil), do: %Statement.Goto{label: label, meta: meta} @doc "Creates a label" - @spec label(String.t(), Meta.t() | nil) :: Stmt.Label.t() - def label(name, meta \\ nil), do: %Stmt.Label{name: name, meta: meta} + @spec label(String.t(), Meta.t() | nil) :: Statement.Label.t() + def label(name, meta \\ nil), do: %Statement.Label{name: name, meta: meta} end diff --git a/lib/lua/ast/expr.ex b/lib/lua/ast/expr.ex index 90b238d..1b8deb1 100644 --- a/lib/lua/ast/expr.ex +++ b/lib/lua/ast/expr.ex @@ -7,22 +7,6 @@ defmodule Lua.AST.Expr do alias Lua.AST.Meta - @type t :: - Nil.t() - | Bool.t() - | Number.t() - | String.t() - | Var.t() - | BinOp.t() - | UnOp.t() - | Table.t() - | Call.t() - | MethodCall.t() - | Index.t() - | Property.t() - | Function.t() - | Vararg.t() - defmodule Nil do @moduledoc "Represents the `nil` literal" defstruct [:meta] @@ -213,4 +197,20 @@ defmodule Lua.AST.Expr do 
defstruct [:meta] @type t :: %__MODULE__{meta: Meta.t() | nil} end + + @type t :: + Nil.t() + | Bool.t() + | Number.t() + | String.t() + | Var.t() + | BinOp.t() + | UnOp.t() + | Table.t() + | Call.t() + | MethodCall.t() + | Index.t() + | Property.t() + | Function.t() + | Vararg.t() end diff --git a/lib/lua/ast/pretty_printer.ex b/lib/lua/ast/pretty_printer.ex index a97219f..2526dc0 100644 --- a/lib/lua/ast/pretty_printer.ex +++ b/lib/lua/ast/pretty_printer.ex @@ -18,13 +18,13 @@ defmodule Lua.AST.PrettyPrinter do PrettyPrinter.print(ast, indent: 4) """ - alias Lua.AST.{Chunk, Block, Expr, Stmt} + alias Lua.AST.{Chunk, Block, Expr, Statement} @type ast_node :: Chunk.t() | Block.t() | Expr.t() - | Stmt.t() + | Statement.t() @type opts :: [ indent: pos_integer() @@ -169,14 +169,14 @@ defmodule Lua.AST.PrettyPrinter do # Statements - defp do_print(%Stmt.Assign{targets: targets, values: values}, level, indent_size) do + defp do_print(%Statement.Assign{targets: targets, values: values}, level, indent_size) do targets_str = Enum.map(targets, &do_print(&1, level, indent_size)) |> Enum.join(", ") values_str = Enum.map(values, &do_print(&1, level, indent_size)) |> Enum.join(", ") "#{indent(level, indent_size)}#{targets_str} = #{values_str}" end - defp do_print(%Stmt.Local{names: names, values: values}, level, indent_size) do + defp do_print(%Statement.Local{names: names, values: values}, level, indent_size) do names_str = Enum.join(names, ", ") if values && values != [] do @@ -187,7 +187,7 @@ defmodule Lua.AST.PrettyPrinter do end end - defp do_print(%Stmt.LocalFunc{name: name, params: params, body: body}, level, indent_size) do + defp do_print(%Statement.LocalFunc{name: name, params: params, body: body}, level, indent_size) do params_str = params |> Enum.map(fn @@ -201,7 +201,7 @@ defmodule Lua.AST.PrettyPrinter do "#{indent(level, indent_size)}local function #{name}(#{params_str})\n#{body_str}#{indent(level, indent_size)}end" end - defp do_print(%Stmt.FuncDecl{name: name, params: params, body: body}, level, indent_size) do + defp do_print(%Statement.FuncDecl{name: name, params: params, body: body}, level, indent_size) do params_str = params |> Enum.map(fn @@ -215,12 +215,12 @@ defmodule Lua.AST.PrettyPrinter do "#{indent(level, indent_size)}function #{format_func_name(name)}(#{params_str})\n#{body_str}#{indent(level, indent_size)}end" end - defp do_print(%Stmt.CallStmt{call: call}, level, indent_size) do + defp do_print(%Statement.CallStmt{call: call}, level, indent_size) do "#{indent(level, indent_size)}#{do_print(call, level, indent_size)}" end defp do_print( - %Stmt.If{ + %Statement.If{ condition: cond, then_block: then_block, elseifs: elseifs, @@ -259,14 +259,14 @@ defmodule Lua.AST.PrettyPrinter do Enum.join(parts, "") <> "#{indent(level, indent_size)}end" end - defp do_print(%Stmt.While{condition: cond, body: body}, level, indent_size) do + defp do_print(%Statement.While{condition: cond, body: body}, level, indent_size) do cond_str = do_print(cond, level, indent_size) body_str = print_block_body(body, level + 1, indent_size) "#{indent(level, indent_size)}while #{cond_str} do\n#{body_str}#{indent(level, indent_size)}end" end - defp do_print(%Stmt.Repeat{body: body, condition: cond}, level, indent_size) do + defp do_print(%Statement.Repeat{body: body, condition: cond}, level, indent_size) do body_str = print_block_body(body, level + 1, indent_size) cond_str = do_print(cond, level, indent_size) @@ -274,7 +274,7 @@ defmodule Lua.AST.PrettyPrinter do end defp do_print( - %Stmt.ForNum{var: var, 
start: start, limit: limit, step: step, body: body}, + %Statement.ForNum{var: var, start: start, limit: limit, step: step, body: body}, level, indent_size ) do @@ -292,7 +292,11 @@ defmodule Lua.AST.PrettyPrinter do "#{indent(level, indent_size)}for #{var} = #{start_str}, #{limit_str}#{step_str} do\n#{body_str}#{indent(level, indent_size)}end" end - defp do_print(%Stmt.ForIn{vars: vars, iterators: iterators, body: body}, level, indent_size) do + defp do_print( + %Statement.ForIn{vars: vars, iterators: iterators, body: body}, + level, + indent_size + ) do vars_str = Enum.join(vars, ", ") iterators_str = Enum.map(iterators, &do_print(&1, level, indent_size)) |> Enum.join(", ") body_str = print_block_body(body, level + 1, indent_size) @@ -300,13 +304,13 @@ defmodule Lua.AST.PrettyPrinter do "#{indent(level, indent_size)}for #{vars_str} in #{iterators_str} do\n#{body_str}#{indent(level, indent_size)}end" end - defp do_print(%Stmt.Do{body: body}, level, indent_size) do + defp do_print(%Statement.Do{body: body}, level, indent_size) do body_str = print_block_body(body, level + 1, indent_size) "#{indent(level, indent_size)}do\n#{body_str}#{indent(level, indent_size)}end" end - defp do_print(%Stmt.Return{values: values}, level, indent_size) do + defp do_print(%Statement.Return{values: values}, level, indent_size) do if values == [] do "#{indent(level, indent_size)}return" else @@ -315,15 +319,15 @@ defmodule Lua.AST.PrettyPrinter do end end - defp do_print(%Stmt.Break{}, level, indent_size) do + defp do_print(%Statement.Break{}, level, indent_size) do "#{indent(level, indent_size)}break" end - defp do_print(%Stmt.Goto{label: label}, level, indent_size) do + defp do_print(%Statement.Goto{label: label}, level, indent_size) do "#{indent(level, indent_size)}goto #{label}" end - defp do_print(%Stmt.Label{name: name}, level, indent_size) do + defp do_print(%Statement.Label{name: name}, level, indent_size) do "#{indent(level, indent_size)}::#{name}::" end diff --git a/lib/lua/ast/stmt.ex b/lib/lua/ast/statement.ex similarity index 99% rename from lib/lua/ast/stmt.ex rename to lib/lua/ast/statement.ex index 331c9f3..567c463 100644 --- a/lib/lua/ast/stmt.ex +++ b/lib/lua/ast/statement.ex @@ -1,4 +1,4 @@ -defmodule Lua.AST.Stmt do +defmodule Lua.AST.Statement do @moduledoc """ Statement AST nodes for Lua. @@ -7,23 +7,6 @@ defmodule Lua.AST.Stmt do alias Lua.AST.{Meta, Expr, Block} - @type t :: - Assign.t() - | Local.t() - | LocalFunc.t() - | FuncDecl.t() - | CallStmt.t() - | If.t() - | While.t() - | Repeat.t() - | ForNum.t() - | ForIn.t() - | Do.t() - | Return.t() - | Break.t() - | Goto.t() - | Label.t() - defmodule Assign do @moduledoc """ Represents an assignment statement: `targets = values` @@ -259,4 +242,21 @@ defmodule Lua.AST.Stmt do meta: Meta.t() | nil } end + + @type t :: + Assign.t() + | Local.t() + | LocalFunc.t() + | FuncDecl.t() + | CallStmt.t() + | If.t() + | While.t() + | Repeat.t() + | ForNum.t() + | ForIn.t() + | Do.t() + | Return.t() + | Break.t() + | Goto.t() + | Label.t() end diff --git a/lib/lua/ast/walker.ex b/lib/lua/ast/walker.ex index ddd0aed..161ea74 100644 --- a/lib/lua/ast/walker.ex +++ b/lib/lua/ast/walker.ex @@ -27,13 +27,13 @@ defmodule Lua.AST.Walker do Walker.walk(ast, fn node -> ... 
end, order: :post) """ - alias Lua.AST.{Chunk, Block, Expr, Stmt} + alias Lua.AST.{Chunk, Block, Expr, Statement} @type ast_node :: Chunk.t() | Block.t() | Expr.t() - | Stmt.t() + | Statement.t() @type visitor :: (ast_node -> any()) @type mapper :: (ast_node -> ast_node) @@ -167,29 +167,29 @@ defmodule Lua.AST.Walker do %{expr | body: do_map(body, mapper)} # Statements - %Stmt.Assign{targets: targets, values: values} = stmt -> + %Statement.Assign{targets: targets, values: values} = stmt -> %{ stmt | targets: Enum.map(targets, &do_map(&1, mapper)), values: Enum.map(values, &do_map(&1, mapper)) } - %Stmt.Local{values: values} = stmt when is_list(values) -> + %Statement.Local{values: values} = stmt when is_list(values) -> %{stmt | values: Enum.map(values, &do_map(&1, mapper))} - %Stmt.Local{} = stmt -> + %Statement.Local{} = stmt -> stmt - %Stmt.LocalFunc{body: body} = stmt -> + %Statement.LocalFunc{body: body} = stmt -> %{stmt | body: do_map(body, mapper)} - %Stmt.FuncDecl{body: body} = stmt -> + %Statement.FuncDecl{body: body} = stmt -> %{stmt | body: do_map(body, mapper)} - %Stmt.CallStmt{call: call} = stmt -> + %Statement.CallStmt{call: call} = stmt -> %{stmt | call: do_map(call, mapper)} - %Stmt.If{ + %Statement.If{ condition: cond, then_block: then_block, elseifs: elseifs, @@ -208,13 +208,13 @@ defmodule Lua.AST.Walker do else_block: mapped_else } - %Stmt.While{condition: cond, body: body} = stmt -> + %Statement.While{condition: cond, body: body} = stmt -> %{stmt | condition: do_map(cond, mapper), body: do_map(body, mapper)} - %Stmt.Repeat{body: body, condition: cond} = stmt -> + %Statement.Repeat{body: body, condition: cond} = stmt -> %{stmt | body: do_map(body, mapper), condition: do_map(cond, mapper)} - %Stmt.ForNum{var: _var, start: start, limit: limit, step: step, body: body} = stmt -> + %Statement.ForNum{var: _var, start: start, limit: limit, step: step, body: body} = stmt -> mapped_step = if step, do: do_map(step, mapper), else: nil %{ @@ -225,17 +225,17 @@ defmodule Lua.AST.Walker do body: do_map(body, mapper) } - %Stmt.ForIn{vars: _vars, iterators: iterators, body: body} = stmt -> + %Statement.ForIn{vars: _vars, iterators: iterators, body: body} = stmt -> %{ stmt | iterators: Enum.map(iterators, &do_map(&1, mapper)), body: do_map(body, mapper) } - %Stmt.Do{body: body} = stmt -> + %Statement.Do{body: body} = stmt -> %{stmt | body: do_map(body, mapper)} - %Stmt.Return{values: values} = stmt -> + %Statement.Return{values: values} = stmt -> %{stmt | values: Enum.map(values, &do_map(&1, mapper))} # Leaf nodes (no children) @@ -291,41 +291,46 @@ defmodule Lua.AST.Walker do [body] # Statements with children - %Stmt.Assign{targets: targets, values: values} -> + %Statement.Assign{targets: targets, values: values} -> targets ++ values - %Stmt.Local{values: values} when is_list(values) -> + %Statement.Local{values: values} when is_list(values) -> values - %Stmt.LocalFunc{body: body} -> + %Statement.LocalFunc{body: body} -> [body] - %Stmt.FuncDecl{body: body} -> + %Statement.FuncDecl{body: body} -> [body] - %Stmt.CallStmt{call: call} -> + %Statement.CallStmt{call: call} -> [call] - %Stmt.If{condition: cond, then_block: then_block, elseifs: elseifs, else_block: else_block} -> + %Statement.If{ + condition: cond, + then_block: then_block, + elseifs: elseifs, + else_block: else_block + } -> elseif_nodes = Enum.flat_map(elseifs, fn {c, b} -> [c, b] end) [cond, then_block | elseif_nodes] ++ if(else_block, do: [else_block], else: []) - %Stmt.While{condition: cond, body: body} -> + 
%Statement.While{condition: cond, body: body} -> [cond, body] - %Stmt.Repeat{body: body, condition: cond} -> + %Statement.Repeat{body: body, condition: cond} -> [body, cond] - %Stmt.ForNum{start: start, limit: limit, step: step, body: body} -> + %Statement.ForNum{start: start, limit: limit, step: step, body: body} -> [start, limit] ++ if(step, do: [step], else: []) ++ [body] - %Stmt.ForIn{iterators: iterators, body: body} -> + %Statement.ForIn{iterators: iterators, body: body} -> iterators ++ [body] - %Stmt.Do{body: body} -> + %Statement.Do{body: body} -> [body] - %Stmt.Return{values: values} -> + %Statement.Return{values: values} -> values # Leaf nodes (no children) diff --git a/lib/lua/parser.ex b/lib/lua/parser.ex index bc22f79..c0006ea 100644 --- a/lib/lua/parser.ex +++ b/lib/lua/parser.ex @@ -5,7 +5,7 @@ defmodule Lua.Parser do Uses Pratt parsing for operator precedence in expressions. """ - alias Lua.AST.{Meta, Expr, Stmt, Block, Chunk} + alias Lua.AST.{Meta, Expr, Statement, Block, Chunk} alias Lua.Parser.Pratt alias Lua.Lexer @@ -164,19 +164,19 @@ defmodule Lua.Parser do case peek(rest) do # End of block or statement {:keyword, terminator, _} when terminator in [:end, :else, :elseif, :until] -> - {:ok, %Stmt.Return{values: [], meta: Meta.new(pos)}, rest} + {:ok, %Statement.Return{values: [], meta: Meta.new(pos)}, rest} {:eof, _} -> - {:ok, %Stmt.Return{values: [], meta: Meta.new(pos)}, rest} + {:ok, %Statement.Return{values: [], meta: Meta.new(pos)}, rest} {:delimiter, :semicolon, _} -> {_, rest2} = consume(rest) - {:ok, %Stmt.Return{values: [], meta: Meta.new(pos)}, rest2} + {:ok, %Statement.Return{values: [], meta: Meta.new(pos)}, rest2} _ -> case parse_expr_list(rest) do {:ok, exprs, rest2} -> - {:ok, %Stmt.Return{values: exprs, meta: Meta.new(pos)}, rest2} + {:ok, %Statement.Return{values: exprs, meta: Meta.new(pos)}, rest2} {:error, reason} -> {:error, reason} @@ -197,7 +197,8 @@ defmodule Lua.Parser do {:ok, _, rest6} <- expect(rest5, :delimiter, :rparen), {:ok, body, rest7} <- parse_block(rest6), {:ok, _, rest8} <- expect(rest7, :keyword, :end) do - {:ok, %Stmt.LocalFunc{name: name, params: params, body: body, meta: Meta.new(pos)}, + {:ok, + %Statement.LocalFunc{name: name, params: params, body: body, meta: Meta.new(pos)}, rest8} end @@ -215,7 +216,8 @@ defmodule Lua.Parser do case parse_expr_list(rest3) do {:ok, values, rest4} -> - {:ok, %Stmt.Local{names: names, values: values, meta: Meta.new(pos)}, rest4} + {:ok, %Statement.Local{names: names, values: values, meta: Meta.new(pos)}, + rest4} {:error, reason} -> {:error, reason} @@ -223,7 +225,7 @@ defmodule Lua.Parser do _ -> # Local without initialization - {:ok, %Stmt.Local{names: names, values: [], meta: Meta.new(pos)}, rest2} + {:ok, %Statement.Local{names: names, values: [], meta: Meta.new(pos)}, rest2} end {:error, reason} -> @@ -244,7 +246,7 @@ defmodule Lua.Parser do case expect(rest5, :keyword, :end) do {:ok, _, rest6} -> {:ok, - %Stmt.If{ + %Statement.If{ condition: condition, then_block: then_block, elseifs: elseifs, @@ -291,7 +293,7 @@ defmodule Lua.Parser do {:ok, _, rest3} <- expect(rest2, :keyword, :do), {:ok, body, rest4} <- parse_block(rest3), {:ok, _, rest5} <- expect(rest4, :keyword, :end) do - {:ok, %Stmt.While{condition: condition, body: body, meta: Meta.new(pos)}, rest5} + {:ok, %Statement.While{condition: condition, body: body, meta: Meta.new(pos)}, rest5} end end @@ -299,7 +301,7 @@ defmodule Lua.Parser do with {:ok, body, rest2} <- parse_block(rest), {:ok, _, rest3} <- expect(rest2, :keyword, 
:until), {:ok, condition, rest4} <- parse_expr(rest3) do - {:ok, %Stmt.Repeat{body: body, condition: condition, meta: Meta.new(pos)}, rest4} + {:ok, %Statement.Repeat{body: body, condition: condition, meta: Meta.new(pos)}, rest4} end end @@ -319,7 +321,7 @@ defmodule Lua.Parser do {:ok, body, rest9} <- parse_block(rest8), {:ok, _, rest10} <- expect(rest9, :keyword, :end) do {:ok, - %Stmt.ForNum{ + %Statement.ForNum{ var: var, start: start, limit: limit, @@ -378,8 +380,12 @@ defmodule Lua.Parser do {:ok, body, rest4} <- parse_block(rest3), {:ok, _, rest5} <- expect(rest4, :keyword, :end) do {:ok, - %Stmt.ForIn{vars: vars, iterators: iterators, body: body, meta: Meta.new(start_pos)}, - rest5} + %Statement.ForIn{ + vars: vars, + iterators: iterators, + body: body, + meta: Meta.new(start_pos) + }, rest5} end _ -> @@ -396,7 +402,7 @@ defmodule Lua.Parser do {:ok, body, rest6} <- parse_block(rest5), {:ok, _, rest7} <- expect(rest6, :keyword, :end) do {:ok, - %Stmt.FuncDecl{ + %Statement.FuncDecl{ name: name_parts, params: params, body: body, @@ -452,18 +458,18 @@ defmodule Lua.Parser do defp parse_do([{:keyword, :do, pos} | rest]) do with {:ok, body, rest2} <- parse_block(rest), {:ok, _, rest3} <- expect(rest2, :keyword, :end) do - {:ok, %Stmt.Do{body: body, meta: Meta.new(pos)}, rest3} + {:ok, %Statement.Do{body: body, meta: Meta.new(pos)}, rest3} end end defp parse_break([{:keyword, :break, pos} | rest]) do - {:ok, %Stmt.Break{meta: Meta.new(pos)}, rest} + {:ok, %Statement.Break{meta: Meta.new(pos)}, rest} end defp parse_goto([{:keyword, :goto, pos} | rest]) do case expect(rest, :identifier) do {:ok, {_, label, _}, rest2} -> - {:ok, %Stmt.Goto{label: label, meta: Meta.new(pos)}, rest2} + {:ok, %Statement.Goto{label: label, meta: Meta.new(pos)}, rest2} {:error, reason} -> {:error, reason} @@ -475,7 +481,7 @@ defmodule Lua.Parser do {:ok, {_, name, _}, rest2} -> case expect(rest2, :delimiter, :double_colon) do {:ok, _, rest3} -> - {:ok, %Stmt.Label{name: name, meta: Meta.new(pos)}, rest3} + {:ok, %Statement.Label{name: name, meta: Meta.new(pos)}, rest3} {:error, reason} -> {:error, reason} @@ -504,10 +510,10 @@ defmodule Lua.Parser do # It's a call statement (or error if not a call) case expr do %Expr.Call{} = call -> - {:ok, %Stmt.CallStmt{call: call, meta: nil}, rest} + {:ok, %Statement.CallStmt{call: call, meta: nil}, rest} %Expr.MethodCall{} = call -> - {:ok, %Stmt.CallStmt{call: call, meta: nil}, rest} + {:ok, %Statement.CallStmt{call: call, meta: nil}, rest} _ -> {:error, {:unexpected_expression, "Expression statement must be a function call"}} @@ -541,7 +547,7 @@ defmodule Lua.Parser do defp parse_assignment(targets, [{:operator, :assign, _} | rest]) do case parse_expr_list(rest) do {:ok, values, rest2} -> - {:ok, %Stmt.Assign{targets: targets, values: values, meta: nil}, rest2} + {:ok, %Statement.Assign{targets: targets, values: values, meta: nil}, rest2} {:error, reason} -> {:error, reason} @@ -1096,12 +1102,8 @@ defmodule Lua.Parser do ) end - defp convert_error({:lexer_error, reason}, code) do - convert_lexer_error(reason, code) - end - - defp convert_error({:not_implemented, feature}, _code) do - Error.new(:invalid_syntax, "Feature not yet implemented: #{feature}", nil) + defp convert_error({:unexpected_expression, message}, _code) do + Error.new(:invalid_syntax, message, nil) end defp convert_error(other, _code) do diff --git a/lib/lua/parser/recovery.ex b/lib/lua/parser/recovery.ex index 0ae193b..48a6970 100644 --- a/lib/lua/parser/recovery.ex +++ b/lib/lua/parser/recovery.ex @@ 
-8,6 +8,7 @@ defmodule Lua.Parser.Recovery do alias Lua.Parser.Error alias Lua.Lexer + alias Lua.AST.Meta @type token :: Lexer.token() @type recovery_result :: {:recovered, [token()], [Error.t()]} | {:failed, [Error.t()]} diff --git a/test/lua/ast/builder_test.exs b/test/lua/ast/builder_test.exs index ceaf392..3184030 100644 --- a/test/lua/ast/builder_test.exs +++ b/test/lua/ast/builder_test.exs @@ -2,17 +2,17 @@ defmodule Lua.AST.BuilderTest do use ExUnit.Case, async: true import Lua.AST.Builder - alias Lua.AST.{Chunk, Block, Expr, Stmt} + alias Lua.AST.{Chunk, Block, Expr, Statement} describe "chunk and block" do test "creates a chunk" do ast = chunk([local(["x"], [number(42)])]) - assert %Chunk{block: %Block{stmts: [%Stmt.Local{}]}} = ast + assert %Chunk{block: %Block{stmts: [%Statement.Local{}]}} = ast end test "creates a block" do blk = block([local(["x"], [number(42)])]) - assert %Block{stmts: [%Stmt.Local{}]} = blk + assert %Block{stmts: [%Statement.Local{}]} = blk end end @@ -194,7 +194,7 @@ defmodule Lua.AST.BuilderTest do assert %Expr.Function{ params: ["x"], - body: %Block{stmts: [%Stmt.Return{}]} + body: %Block{stmts: [%Statement.Return{}]} } = fn_expr end @@ -213,7 +213,7 @@ defmodule Lua.AST.BuilderTest do test "creates assignment" do stmt = assign([var("x")], [number(42)]) - assert %Stmt.Assign{ + assert %Statement.Assign{ targets: [%Expr.Var{name: "x"}], values: [%Expr.Number{value: 42}] } = stmt @@ -221,23 +221,23 @@ defmodule Lua.AST.BuilderTest do test "creates multiple assignment" do stmt = assign([var("x"), var("y")], [number(1), number(2)]) - assert %Stmt.Assign{targets: [_, _], values: [_, _]} = stmt + assert %Statement.Assign{targets: [_, _], values: [_, _]} = stmt end test "creates local declaration" do stmt = local(["x"], [number(42)]) - assert %Stmt.Local{names: ["x"], values: [%Expr.Number{value: 42}]} = stmt + assert %Statement.Local{names: ["x"], values: [%Expr.Number{value: 42}]} = stmt end test "creates local declaration without value" do stmt = local(["x"], []) - assert %Stmt.Local{names: ["x"], values: []} = stmt + assert %Statement.Local{names: ["x"], values: []} = stmt end test "creates local function" do stmt = local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) - assert %Stmt.LocalFunc{ + assert %Statement.LocalFunc{ name: "add", params: ["a", "b"], body: %Block{} @@ -246,42 +246,42 @@ defmodule Lua.AST.BuilderTest do test "creates function declaration with string name" do stmt = func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) - assert %Stmt.FuncDecl{name: ["add"], params: ["a", "b"]} = stmt + assert %Statement.FuncDecl{name: ["add"], params: ["a", "b"]} = stmt end test "creates function declaration with path name" do stmt = func_decl(["math", "add"], ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])]) - assert %Stmt.FuncDecl{name: ["math", "add"]} = stmt + assert %Statement.FuncDecl{name: ["math", "add"]} = stmt end test "creates call statement" do stmt = call_stmt(call(var("print"), [string("hello")])) - assert %Stmt.CallStmt{call: %Expr.Call{}} = stmt + assert %Statement.CallStmt{call: %Expr.Call{}} = stmt end test "creates return statement" do stmt = return_stmt([]) - assert %Stmt.Return{values: []} = stmt + assert %Statement.Return{values: []} = stmt stmt = return_stmt([number(42)]) - assert %Stmt.Return{values: [%Expr.Number{value: 42}]} = stmt + assert %Statement.Return{values: [%Expr.Number{value: 42}]} = stmt end test "creates break statement" do stmt = break_stmt() - assert 
%Stmt.Break{} = stmt + assert %Statement.Break{} = stmt end test "creates goto statement" do stmt = goto_stmt("label") - assert %Stmt.Goto{label: "label"} = stmt + assert %Statement.Goto{label: "label"} = stmt end test "creates label" do stmt = label("label") - assert %Stmt.Label{name: "label"} = stmt + assert %Statement.Label{name: "label"} = stmt end end @@ -289,9 +289,9 @@ defmodule Lua.AST.BuilderTest do test "creates if statement" do stmt = if_stmt(var("x"), [return_stmt([number(1)])]) - assert %Stmt.If{ + assert %Statement.If{ condition: %Expr.Var{name: "x"}, - then_block: %Block{stmts: [%Stmt.Return{}]}, + then_block: %Block{stmts: [%Statement.Return{}]}, elseifs: [], else_block: nil } = stmt @@ -305,7 +305,7 @@ defmodule Lua.AST.BuilderTest do else: [return_stmt([number(0)])] ) - assert %Stmt.If{else_block: %Block{}} = stmt + assert %Statement.If{else_block: %Block{}} = stmt end test "creates if-elseif-else statement" do @@ -317,7 +317,7 @@ defmodule Lua.AST.BuilderTest do else: [return_stmt([number(0)])] ) - assert %Stmt.If{ + assert %Statement.If{ elseifs: [{_, %Block{}}], else_block: %Block{} } = stmt @@ -329,7 +329,7 @@ defmodule Lua.AST.BuilderTest do assign([var("x")], [binop(:sub, var("x"), number(1))]) ]) - assert %Stmt.While{ + assert %Statement.While{ condition: %Expr.BinOp{op: :gt}, body: %Block{} } = stmt @@ -342,7 +342,7 @@ defmodule Lua.AST.BuilderTest do binop(:le, var("x"), number(0)) ) - assert %Stmt.Repeat{ + assert %Statement.Repeat{ body: %Block{}, condition: %Expr.BinOp{op: :le} } = stmt @@ -354,7 +354,7 @@ defmodule Lua.AST.BuilderTest do call_stmt(call(var("print"), [var("i")])) ]) - assert %Stmt.ForNum{ + assert %Statement.ForNum{ var: "i", start: %Expr.Number{value: 1}, limit: %Expr.Number{value: 10}, @@ -371,9 +371,11 @@ defmodule Lua.AST.BuilderTest do number(10), [ call_stmt(call(var("print"), [var("i")])) - ], step: number(2)) + ], + step: number(2) + ) - assert %Stmt.ForNum{step: %Expr.Number{value: 2}} = stmt + assert %Statement.ForNum{step: %Expr.Number{value: 2}} = stmt end test "creates generic for loop" do @@ -384,7 +386,7 @@ defmodule Lua.AST.BuilderTest do [call_stmt(call(var("print"), [var("k"), var("v")]))] ) - assert %Stmt.ForIn{ + assert %Statement.ForIn{ vars: ["k", "v"], iterators: [%Expr.Call{}], body: %Block{} @@ -398,7 +400,7 @@ defmodule Lua.AST.BuilderTest do call_stmt(call(var("print"), [var("x")])) ]) - assert %Stmt.Do{body: %Block{stmts: [_, _]}} = stmt + assert %Statement.Do{body: %Block{stmts: [_, _]}} = stmt end end @@ -419,11 +421,11 @@ defmodule Lua.AST.BuilderTest do assert %Chunk{ block: %Block{ stmts: [ - %Stmt.FuncDecl{ + %Statement.FuncDecl{ name: ["outer"], body: %Block{ stmts: [ - %Stmt.Return{ + %Statement.Return{ values: [%Expr.Function{}] } ] @@ -451,7 +453,7 @@ defmodule Lua.AST.BuilderTest do assert %Chunk{ block: %Block{ stmts: [ - %Stmt.If{ + %Statement.If{ elseifs: [{_, _}, {_, _}], else_block: %Block{} } @@ -478,9 +480,9 @@ defmodule Lua.AST.BuilderTest do assert %Chunk{ block: %Block{ stmts: [ - %Stmt.ForNum{ + %Statement.ForNum{ body: %Block{ - stmts: [%Stmt.ForNum{}] + stmts: [%Statement.ForNum{}] } } ] diff --git a/test/lua/ast/pretty_printer_test.exs b/test/lua/ast/pretty_printer_test.exs index 644fb9d..996fd76 100644 --- a/test/lua/ast/pretty_printer_test.exs +++ b/test/lua/ast/pretty_printer_test.exs @@ -1074,12 +1074,12 @@ defmodule Lua.AST.PrettyPrinterTest do describe "binary name for function declaration" do test "prints function with string name" do # When passing a simple string name (not a list) - 
alias Lua.AST.Stmt + alias Lua.AST.Statement ast = %Lua.AST.Chunk{ block: %Lua.AST.Block{ stmts: [ - %Stmt.FuncDecl{ + %Statement.FuncDecl{ name: "simple", params: [], body: %Lua.AST.Block{stmts: [return_stmt([])]}, diff --git a/test/lua/ast/walker_test.exs b/test/lua/ast/walker_test.exs index 292b4e8..50a545f 100644 --- a/test/lua/ast/walker_test.exs +++ b/test/lua/ast/walker_test.exs @@ -2,7 +2,7 @@ defmodule Lua.AST.WalkerTest do use ExUnit.Case, async: true import Lua.AST.Builder - alias Lua.AST.{Walker, Expr, Stmt} + alias Lua.AST.{Walker, Expr, Statement} describe "walk/2" do test "visits all nodes in pre-order" do @@ -158,7 +158,7 @@ defmodule Lua.AST.WalkerTest do end) # Extract the if statement - [%Stmt.If{condition: %Expr.Bool{value: true}}] = transformed.block.stmts + [%Statement.If{condition: %Expr.Bool{value: true}}] = transformed.block.stmts # Number should be transformed numbers = @@ -227,7 +227,7 @@ defmodule Lua.AST.WalkerTest do # Build map of local declarations: name -> value locals = Walker.reduce(ast, %{}, fn - %Stmt.Local{names: [name], values: [%Expr.Number{value: n}]}, acc -> + %Statement.Local{names: [name], values: [%Expr.Number{value: n}]}, acc -> Map.put(acc, name, n) _, acc -> @@ -525,7 +525,7 @@ defmodule Lua.AST.WalkerTest do # Count local statements local_count = Walker.reduce(ast, 0, fn - %Stmt.Local{}, acc -> acc + 1 + %Statement.Local{}, acc -> acc + 1 _, acc -> acc end) @@ -548,14 +548,14 @@ defmodule Lua.AST.WalkerTest do # Transform should preserve empty values list transformed = Walker.map(ast, fn - %Stmt.Local{values: []} = node -> node + %Statement.Local{values: []} = node -> node node -> node end) # Extract local statement locals = Walker.reduce(transformed, [], fn - %Stmt.Local{names: names, values: values}, acc -> [{names, values} | acc] + %Statement.Local{names: names, values: values}, acc -> [{names, values} | acc] _, acc -> acc end) @@ -688,7 +688,7 @@ defmodule Lua.AST.WalkerTest do # Verify step is nil step_is_nil = Walker.reduce(ast, false, fn - %Stmt.ForNum{step: nil}, _acc -> true + %Statement.ForNum{step: nil}, _acc -> true _, acc -> acc end) @@ -813,7 +813,7 @@ defmodule Lua.AST.WalkerTest do # Count do statements do_count = Walker.reduce(ast, 0, fn - %Stmt.Do{}, acc -> acc + 1 + %Statement.Do{}, acc -> acc + 1 _, acc -> acc end) @@ -847,7 +847,7 @@ defmodule Lua.AST.WalkerTest do # Count break statements break_count = Walker.reduce(ast, 0, fn - %Stmt.Break{}, acc -> acc + 1 + %Statement.Break{}, acc -> acc + 1 _, acc -> acc end) @@ -866,7 +866,7 @@ defmodule Lua.AST.WalkerTest do break_count = Walker.reduce(transformed, 0, fn - %Stmt.Break{}, acc -> acc + 1 + %Statement.Break{}, acc -> acc + 1 _, acc -> acc end) @@ -880,7 +880,7 @@ defmodule Lua.AST.WalkerTest do # Count goto statements goto_labels = Walker.reduce(ast, [], fn - %Stmt.Goto{label: label}, acc -> [label | acc] + %Statement.Goto{label: label}, acc -> [label | acc] _, acc -> acc end) @@ -899,7 +899,7 @@ defmodule Lua.AST.WalkerTest do labels = Walker.reduce(transformed, [], fn - %Stmt.Goto{label: label}, acc -> [label | acc] + %Statement.Goto{label: label}, acc -> [label | acc] _, acc -> acc end) @@ -913,7 +913,7 @@ defmodule Lua.AST.WalkerTest do # Count labels labels = Walker.reduce(ast, [], fn - %Stmt.Label{name: name}, acc -> [name | acc] + %Statement.Label{name: name}, acc -> [name | acc] _, acc -> acc end) @@ -932,7 +932,7 @@ defmodule Lua.AST.WalkerTest do labels = Walker.reduce(transformed, [], fn - %Stmt.Label{name: name}, acc -> [name | acc] + 
%Statement.Label{name: name}, acc -> [name | acc] _, acc -> acc end) @@ -984,7 +984,7 @@ defmodule Lua.AST.WalkerTest do # Verify structure if_stmts = Walker.reduce(ast, [], fn - %Stmt.If{elseifs: elseifs, else_block: else_block}, acc -> + %Statement.If{elseifs: elseifs, else_block: else_block}, acc -> [{elseifs, else_block} | acc] _, acc -> @@ -1054,7 +1054,7 @@ defmodule Lua.AST.WalkerTest do locals = Walker.reduce(transformed, [], fn - %Stmt.Local{names: names, values: values}, acc -> [{names, values} | acc] + %Statement.Local{names: names, values: values}, acc -> [{names, values} | acc] _, acc -> acc end) @@ -1200,7 +1200,7 @@ defmodule Lua.AST.WalkerTest do # Should walk through CallStmt to MethodCall call_stmt_count = Walker.reduce(ast, 0, fn - %Stmt.CallStmt{}, acc -> acc + 1 + %Statement.CallStmt{}, acc -> acc + 1 _, acc -> acc end) diff --git a/test/lua/parser/error_test.exs b/test/lua/parser/error_test.exs index 5bd15c5..754f17a 100644 --- a/test/lua/parser/error_test.exs +++ b/test/lua/parser/error_test.exs @@ -153,10 +153,14 @@ defmodule Lua.Parser.ErrorTest do code = "if x then" assert {:error, msg} = Parser.parse(code) - assert msg =~ "\e[31m" # Red for errors - assert msg =~ "\e[1m" # Bold - assert msg =~ "\e[0m" # Reset - assert msg =~ "\e[36m" # Cyan for suggestions + # Red for errors + assert msg =~ "\e[31m" + # Bold + assert msg =~ "\e[1m" + # Reset + assert msg =~ "\e[0m" + # Cyan for suggestions + assert msg =~ "\e[36m" end end diff --git a/test/lua/parser/expr_test.exs b/test/lua/parser/expr_test.exs index 02cec9d..cf941b3 100644 --- a/test/lua/parser/expr_test.exs +++ b/test/lua/parser/expr_test.exs @@ -1,15 +1,15 @@ defmodule Lua.Parser.ExprTest do use ExUnit.Case, async: true alias Lua.Parser - alias Lua.AST.{Expr, Stmt} + alias Lua.AST.{Expr, Statement} # Helper to extract the returned expression from "return expr" defp parse_return_expr(code) do case Parser.parse(code) do - {:ok, %{block: %{stmts: [%Stmt.Return{values: [expr]}]}}} -> + {:ok, %{block: %{stmts: [%Statement.Return{values: [expr]}]}}} -> {:ok, expr} - {:ok, %{block: %{stmts: [%Stmt.Return{values: exprs}]}}} -> + {:ok, %{block: %{stmts: [%Statement.Return{values: exprs}]}}} -> {:ok, exprs} other -> diff --git a/test/lua/parser/stmt_test.exs b/test/lua/parser/statement_test.exs similarity index 81% rename from test/lua/parser/stmt_test.exs rename to test/lua/parser/statement_test.exs index 5697f88..b631193 100644 --- a/test/lua/parser/stmt_test.exs +++ b/test/lua/parser/statement_test.exs @@ -1,12 +1,12 @@ -defmodule Lua.Parser.StmtTest do +defmodule Lua.Parser.StatementTest do use ExUnit.Case, async: true alias Lua.Parser - alias Lua.AST.{Stmt, Expr} + alias Lua.AST.{Statement, Expr} describe "local variable declarations" do test "parses local without initialization" do assert {:ok, chunk} = Parser.parse("local x") - assert %{block: %{stmts: [%Stmt.Local{names: ["x"], values: []}]}} = chunk + assert %{block: %{stmts: [%Statement.Local{names: ["x"], values: []}]}} = chunk end test "parses local with single initialization" do @@ -14,7 +14,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ - stmts: [%Stmt.Local{names: ["x"], values: [%Expr.Number{value: 42}]}] + stmts: [%Statement.Local{names: ["x"], values: [%Expr.Number{value: 42}]}] } } = chunk end @@ -25,7 +25,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Local{ + %Statement.Local{ names: ["x", "y", "z"], values: [ %Expr.Number{value: 1}, @@ -46,7 +46,7 @@ defmodule Lua.Parser.StmtTest do end """) - assert %{block: 
%{stmts: [%Stmt.LocalFunc{name: "add", params: ["a", "b"]}]}} = chunk + assert %{block: %{stmts: [%Statement.LocalFunc{name: "add", params: ["a", "b"]}]}} = chunk end end @@ -57,7 +57,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Assign{ + %Statement.Assign{ targets: [%Expr.Var{name: "x"}], values: [%Expr.Number{value: 42}] } @@ -72,7 +72,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Assign{ + %Statement.Assign{ targets: [%Expr.Var{name: "x"}, %Expr.Var{name: "y"}], values: [%Expr.Number{value: 1}, %Expr.Number{value: 2}] } @@ -87,7 +87,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Assign{ + %Statement.Assign{ targets: [%Expr.Property{}], values: [%Expr.Number{value: 42}] } @@ -102,7 +102,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Assign{ + %Statement.Assign{ targets: [%Expr.Index{}], values: [%Expr.Number{value: 42}] } @@ -118,7 +118,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ - stmts: [%Stmt.CallStmt{call: %Expr.Call{func: %Expr.Var{name: "print"}}}] + stmts: [%Statement.CallStmt{call: %Expr.Call{func: %Expr.Var{name: "print"}}}] } } = chunk end @@ -127,7 +127,7 @@ defmodule Lua.Parser.StmtTest do assert {:ok, chunk} = Parser.parse("obj:method()") assert %{ - block: %{stmts: [%Stmt.CallStmt{call: %Expr.MethodCall{method: "method"}}]} + block: %{stmts: [%Statement.CallStmt{call: %Expr.MethodCall{method: "method"}}]} } = chunk end end @@ -144,9 +144,9 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.If{ + %Statement.If{ condition: %Expr.BinOp{op: :gt}, - then_block: %{stmts: [%Stmt.Return{}]}, + then_block: %{stmts: [%Statement.Return{}]}, elseifs: [], else_block: nil } @@ -168,11 +168,11 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.If{ + %Statement.If{ condition: %Expr.BinOp{op: :gt}, - then_block: %{stmts: [%Stmt.Return{}]}, + then_block: %{stmts: [%Statement.Return{}]}, elseifs: [], - else_block: %{stmts: [%Stmt.Return{}]} + else_block: %{stmts: [%Statement.Return{}]} } ] } @@ -194,11 +194,11 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.If{ + %Statement.If{ condition: %Expr.BinOp{op: :gt}, - then_block: %{stmts: [%Stmt.Return{}]}, - elseifs: [{%Expr.BinOp{op: :lt}, %{stmts: [%Stmt.Return{}]}}], - else_block: %{stmts: [%Stmt.Return{}]} + then_block: %{stmts: [%Statement.Return{}]}, + elseifs: [{%Expr.BinOp{op: :lt}, %{stmts: [%Statement.Return{}]}}], + else_block: %{stmts: [%Statement.Return{}]} } ] } @@ -220,7 +220,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.If{ + %Statement.If{ elseifs: [_, _] } ] @@ -241,9 +241,9 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.While{ + %Statement.While{ condition: %Expr.BinOp{op: :gt}, - body: %{stmts: [%Stmt.Assign{}]} + body: %{stmts: [%Statement.Assign{}]} } ] } @@ -263,8 +263,8 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Repeat{ - body: %{stmts: [%Stmt.Assign{}]}, + %Statement.Repeat{ + body: %{stmts: [%Statement.Assign{}]}, condition: %Expr.BinOp{op: :eq} } ] @@ -285,12 +285,12 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.ForNum{ + %Statement.ForNum{ var: "i", start: %Expr.Number{value: 1}, limit: %Expr.Number{value: 10}, step: nil, - body: %{stmts: [%Stmt.CallStmt{}]} + body: %{stmts: [%Statement.CallStmt{}]} } ] } @@ -308,7 +308,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.ForNum{ + 
%Statement.ForNum{ var: "i", start: %Expr.Number{value: 1}, limit: %Expr.Number{value: 10}, @@ -330,10 +330,10 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.ForIn{ + %Statement.ForIn{ vars: ["k", "v"], iterators: [%Expr.Call{}], - body: %{stmts: [%Stmt.CallStmt{}]} + body: %{stmts: [%Statement.CallStmt{}]} } ] } @@ -351,7 +351,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.ForIn{ + %Statement.ForIn{ vars: ["line"], iterators: [%Expr.Call{func: %Expr.Property{}}] } @@ -373,11 +373,11 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.FuncDecl{ + %Statement.FuncDecl{ name: ["add"], params: ["a", "b"], is_method: false, - body: %{stmts: [%Stmt.Return{}]} + body: %{stmts: [%Statement.Return{}]} } ] } @@ -395,7 +395,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.FuncDecl{ + %Statement.FuncDecl{ name: ["math", "abs"], is_method: false } @@ -415,7 +415,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.FuncDecl{ + %Statement.FuncDecl{ name: ["obj", "method"], is_method: true, params: ["x"] @@ -435,7 +435,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.FuncDecl{ + %Statement.FuncDecl{ name: ["a", "b", "c", "d"], is_method: false } @@ -458,8 +458,8 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Do{ - body: %{stmts: [%Stmt.Local{}, %Stmt.CallStmt{}]} + %Statement.Do{ + body: %{stmts: [%Statement.Local{}, %Statement.CallStmt{}]} } ] } @@ -470,17 +470,17 @@ defmodule Lua.Parser.StmtTest do describe "break and goto" do test "parses break" do assert {:ok, chunk} = Parser.parse("break") - assert %{block: %{stmts: [%Stmt.Break{}]}} = chunk + assert %{block: %{stmts: [%Statement.Break{}]}} = chunk end test "parses goto" do assert {:ok, chunk} = Parser.parse("goto finish") - assert %{block: %{stmts: [%Stmt.Goto{label: "finish"}]}} = chunk + assert %{block: %{stmts: [%Statement.Goto{label: "finish"}]}} = chunk end test "parses label" do assert {:ok, chunk} = Parser.parse("::finish::") - assert %{block: %{stmts: [%Stmt.Label{name: "finish"}]}} = chunk + assert %{block: %{stmts: [%Statement.Label{name: "finish"}]}} = chunk end test "parses goto and label together" do @@ -495,10 +495,10 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Goto{label: "skip"}, - %Stmt.CallStmt{}, - %Stmt.Label{name: "skip"}, - %Stmt.CallStmt{} + %Statement.Goto{label: "skip"}, + %Statement.CallStmt{}, + %Statement.Label{name: "skip"}, + %Statement.CallStmt{} ] } } = chunk @@ -508,14 +508,14 @@ defmodule Lua.Parser.StmtTest do describe "return statements" do test "parses return with no values" do assert {:ok, chunk} = Parser.parse("return") - assert %{block: %{stmts: [%Stmt.Return{values: []}]}} = chunk + assert %{block: %{stmts: [%Statement.Return{values: []}]}} = chunk end test "parses return with single value" do assert {:ok, chunk} = Parser.parse("return 42") assert %{ - block: %{stmts: [%Stmt.Return{values: [%Expr.Number{value: 42}]}]} + block: %{stmts: [%Statement.Return{values: [%Expr.Number{value: 42}]}]} } = chunk end @@ -525,7 +525,7 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Return{ + %Statement.Return{ values: [ %Expr.Number{value: 1}, %Expr.Number{value: 2}, @@ -551,7 +551,7 @@ defmodule Lua.Parser.StmtTest do end """) - assert %{block: %{stmts: [%Stmt.FuncDecl{name: ["factorial"]}]}} = chunk + assert %{block: %{stmts: [%Statement.FuncDecl{name: ["factorial"]}]}} = chunk end test "parses 
multiple statements" do @@ -566,10 +566,10 @@ defmodule Lua.Parser.StmtTest do assert %{ block: %{ stmts: [ - %Stmt.Local{names: ["x"]}, - %Stmt.Local{names: ["y"]}, - %Stmt.Local{names: ["sum"]}, - %Stmt.CallStmt{} + %Statement.Local{names: ["x"]}, + %Statement.Local{names: ["y"]}, + %Statement.Local{names: ["sum"]}, + %Statement.CallStmt{} ] } } = chunk